@@ -2046,44 +2046,38 @@ def broadcast(self, ctxt: NetworkContext, default_channels_first: bool = True) -
         inputShapes = [ctxt.lookup(node.name).shape for node in self.node.inputs]
         outputShapes = [ctxt.lookup(node.name).shape for node in self.node.outputs]

-        if not "channels_first" in self.mapper.parser.operatorRepresentation:
-            channels_first = default_channels_first
-        else:
-            channels_first = self.mapper.parser.operatorRepresentation['channels_first']
+        opRepr = self.mapper.parser.operatorRepresentation
+        channels_first = opRepr.get("channels_first", default_channels_first)
+        newInputShapes, newOutputShapes = self.computeShapes(inputShapes, outputShapes, opRepr, channels_first)

-        newInputShapes, newOutputShapes = self.computeShapes(inputShapes, outputShapes,
-                                                              self.mapper.parser.operatorRepresentation, channels_first)
+        for tensor, shape in zip(self.node.inputs + self.node.outputs, newInputShapes + newOutputShapes):
+            buffer = ctxt.lookup(tensor.name)
+            assert isinstance(buffer, VariableBuffer)

-        for node, newShape in zip(self.node.inputs + self.node.outputs, newInputShapes + newOutputShapes):
-            if ctxt.is_local(node.name):
-                ctxt.localObjects[node.name].shape = newShape
+            if ctxt.is_local(tensor.name):
+                buffer.shape = shape
                 # Update shape of tensors in onnx graph
-                node.shape = newShape
+                tensor.shape = shape

                 # WIESEP: It is possible that the type was not yet set, so we assume some default type
                 # At this state, we assume that all local buffers are float32 type inference is not yet done.
-                if node.dtype is None:
-                    node.dtype = np.float32
+                if tensor.dtype is None:
+                    tensor.dtype = np.float32

-            elif ctxt.is_global(node.name):
-                ctxt.globalObjects[node.name].shape = newShape
-                if isinstance(ctxt.globalObjects[node.name], ConstantBuffer):
+            elif ctxt.is_global(tensor.name):
+                buffer.shape = shape
+                if isinstance(buffer, ConstantBuffer):

                     # If the number of elements is equal, reshape
-                    if np.prod(ctxt.globalObjects[node.name].values.shape) == np.prod(newShape):
-                        ctxt.globalObjects[node.name].values.reshape(newShape)
+                    if np.prod(buffer.values.shape) == np.prod(shape):
+                        buffer.values = buffer.values.reshape(shape)
                     # The number of elements SHOULD be lower, and we broadcast
                     else:
                         try:
-                            ctxt.globalObjects[node.name].values = np.broadcast_to(ctxt.globalObjects[node.name].values,
-                                                                                    newShape)
-                        except:
-                            raise RuntimeError(
-                                f"Could not broadcast {node.name} from {ctxt.globalObjects[node.name].values.shape} to {newShape}!"
-                            )
-
-            else:
-                raise KeyError(f'Expected node {node.name} to be in context!')
+                            buffer.values = np.broadcast_to(buffer.values, shape)
+                        except ValueError as e:
+                            raise ValueError(
+                                f"Could not broadcast tensor {tensor.name} of node {self.node.name}.") from e

         return ctxt

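For readers skimming the constant-buffer branch: when the element count of a constant already matches the target shape its values are only reshaped, otherwise `np.broadcast_to` replicates them along the compatible axes. Below is a minimal NumPy sketch of that decision, detached from Deeploy's `ConstantBuffer`; the `fit_constant`, `values`, and `new_shape` names are illustrative only and not part of the codebase.

```python
import numpy as np

def fit_constant(values: np.ndarray, new_shape: tuple) -> np.ndarray:
    """Adapt a constant's values to `new_shape`: reshape when the element
    counts match, otherwise broadcast (the constant must have fewer elements)."""
    if np.prod(values.shape) == np.prod(new_shape):
        # Same number of elements: a plain reshape is enough.
        return values.reshape(new_shape)
    try:
        # Fewer elements: NumPy replicates the data along compatible axes.
        return np.broadcast_to(values, new_shape)
    except ValueError as e:
        raise ValueError(f"Could not broadcast {values.shape} to {new_shape}.") from e

bias = np.arange(8.0)                        # per-channel constant, shape (8,)
print(fit_constant(bias, (2, 4)).shape)      # (2, 4)    -- reshaped, same element count
print(fit_constant(bias, (1, 4, 8)).shape)   # (1, 4, 8) -- broadcast along the new axes
```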