Skip to content

Commit b10c0a5

Browse files
committed
Broadcast refactor
1 parent 54e5b57 commit b10c0a5

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

Deeploy/DeeployTypes.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,44 +2046,38 @@ def broadcast(self, ctxt: NetworkContext, default_channels_first: bool = True) -
20462046
inputShapes = [ctxt.lookup(node.name).shape for node in self.node.inputs]
20472047
outputShapes = [ctxt.lookup(node.name).shape for node in self.node.outputs]
20482048

2049-
if not "channels_first" in self.mapper.parser.operatorRepresentation:
2050-
channels_first = default_channels_first
2051-
else:
2052-
channels_first = self.mapper.parser.operatorRepresentation['channels_first']
2049+
opRepr = self.mapper.parser.operatorRepresentation
2050+
channels_first = opRepr.get("channels_first", default_channels_first)
2051+
newInputShapes, newOutputShapes = self.computeShapes(inputShapes, outputShapes, opRepr, channels_first)
20532052

2054-
newInputShapes, newOutputShapes = self.computeShapes(inputShapes, outputShapes,
2055-
self.mapper.parser.operatorRepresentation, channels_first)
2053+
for tensor, shape in zip(self.node.inputs + self.node.outputs, newInputShapes + newOutputShapes):
2054+
buffer = ctxt.lookup(tensor.name)
2055+
assert isinstance(buffer, VariableBuffer)
20562056

2057-
for node, newShape in zip(self.node.inputs + self.node.outputs, newInputShapes + newOutputShapes):
2058-
if ctxt.is_local(node.name):
2059-
ctxt.localObjects[node.name].shape = newShape
2057+
if ctxt.is_local(tensor.name):
2058+
buffer.shape = shape
20602059
# Update shape of tensors in onnx graph
2061-
node.shape = newShape
2060+
tensor.shape = shape
20622061

20632062
# WIESEP: It is possible that the type was not yet set, so we assume some default type
20642063
# At this state, we assume that all local buffers are float32 type inference is not yet done.
2065-
if node.dtype is None:
2066-
node.dtype = np.float32
2064+
if tensor.dtype is None:
2065+
tensor.dtype = np.float32
20672066

2068-
elif ctxt.is_global(node.name):
2069-
ctxt.globalObjects[node.name].shape = newShape
2070-
if isinstance(ctxt.globalObjects[node.name], ConstantBuffer):
2067+
elif ctxt.is_global(tensor.name):
2068+
buffer.shape = shape
2069+
if isinstance(buffer, ConstantBuffer):
20712070

20722071
# If the number of elements is equal, reshape
2073-
if np.prod(ctxt.globalObjects[node.name].values.shape) == np.prod(newShape):
2074-
ctxt.globalObjects[node.name].values.reshape(newShape)
2072+
if np.prod(buffer.values.shape) == np.prod(shape):
2073+
buffer.values.reshape(shape)
20752074
# The number of elements SHOULD be lower, and we broadcast
20762075
else:
20772076
try:
2078-
ctxt.globalObjects[node.name].values = np.broadcast_to(ctxt.globalObjects[node.name].values,
2079-
newShape)
2080-
except:
2081-
raise RuntimeError(
2082-
f"Could not broadcast {node.name} from {ctxt.globalObjects[node.name].values.shape} to {newShape}!"
2083-
)
2084-
2085-
else:
2086-
raise KeyError(f'Expected node {node.name} to be in context!')
2077+
buffer.values = np.broadcast_to(buffer.values, shape)
2078+
except ValueError as e:
2079+
raise ValueError(
2080+
f"Could not broadcast tensor {tensor.name} of node {self.node.name}.") from e
20872081

20882082
return ctxt
20892083

0 commit comments

Comments
 (0)