@@ -1605,27 +1605,40 @@ def parseNodeCtxt(self,
16051605 node .inputs .append (zeroTensor )
16061606 self .operatorRepresentation ['C' ] = f'{ node .name } _C_Tensor'
16071607
1608+ buffA = ctxt .lookup (node .inputs [0 ].name )
1609+ assert isinstance (buffA , VariableBuffer )
1610+ buffB = ctxt .lookup (node .inputs [1 ].name )
1611+ assert isinstance (buffB , VariableBuffer )
1612+ buffOut = ctxt .lookup (node .outputs [0 ].name )
1613+ assert isinstance (buffOut , VariableBuffer )
1614+
16081615 # Store the input and output shapes in the operator representation
1609- self .operatorRepresentation ['size' ] = np .prod (ctxt .lookup (node .inputs [0 ].name ).shape )
1610- self .operatorRepresentation ['A_shape' ] = ctxt .lookup (node .inputs [0 ].name ).shape
1611- self .operatorRepresentation ['B_shape' ] = ctxt .lookup (node .inputs [1 ].name ).shape
1612- self .operatorRepresentation ['data_out_shape' ] = ctxt .lookup (node .outputs [0 ].name ).shape
1616+ self .operatorRepresentation ['size' ] = np .prod (buffA .shape )
1617+ self .operatorRepresentation ['A_shape' ] = buffA .shape
1618+ self .operatorRepresentation ['B_shape' ] = buffB .shape
1619+ self .operatorRepresentation ['data_out_shape' ] = buffOut .shape
1620+
1621+ if self .operatorRepresentation ['transA' ]:
1622+ N_A , M = buffA .shape [- 2 :]
1623+ else :
1624+ M , N_A = buffA .shape [- 2 :]
1625+
1626+ if self .operatorRepresentation ['transB' ]:
1627+ O , N_B = buffB .shape [- 2 :]
1628+ else :
1629+ N_B , O = buffB .shape [- 2 :]
16131630
16141631 # Store the matrix dimensions in the operator representation
1615- self .operatorRepresentation ['M' ] = ctxt .lookup (
1616- node .inputs [0 ].name ).shape [(- 2 + self .operatorRepresentation ['transA' ])]
1617- self .operatorRepresentation ['N' ] = ctxt .lookup (
1618- node .inputs [0 ].name ).shape [(- 1 - self .operatorRepresentation ['transA' ])]
1619- self .operatorRepresentation ['O' ] = ctxt .lookup (
1620- node .inputs [1 ].name ).shape [(- 1 - self .operatorRepresentation ['transB' ])]
1632+ self .operatorRepresentation ['M' ] = M
1633+ self .operatorRepresentation ['N' ] = N_A
1634+ self .operatorRepresentation ['O' ] = O
16211635
16221636 # SCHEREMO: Assert that reduction dimension is the same on both matrices
1623- ret = ret and (self .operatorRepresentation ['N' ] == ctxt .lookup (
1624- node .inputs [1 ].name ).shape [- 2 + self .operatorRepresentation ['transB' ]])
1637+ ret = ret and N_A == N_B
16251638
16261639 # Check if the batch dimensions are compatible
1627- self .operatorRepresentation ['batch_A' ] = np .prod (ctxt . lookup ( node . inputs [ 0 ]. name ) .shape [:- 2 ])
1628- self .operatorRepresentation ['batch_B' ] = np .prod (ctxt . lookup ( node . inputs [ 1 ]. name ) .shape [:- 2 ])
1640+ self .operatorRepresentation ['batch_A' ] = np .prod (buffA .shape [:- 2 ])
1641+ self .operatorRepresentation ['batch_B' ] = np .prod (buffB .shape [:- 2 ])
16291642
16301643 self .operatorRepresentation ['batch' ] = max (self .operatorRepresentation ['batch_A' ],
16311644 self .operatorRepresentation ['batch_B' ])
@@ -1637,10 +1650,10 @@ def parseNodeCtxt(self,
16371650 ), "Incompatible dimensions for input matrices. Broadcasting not yet supported for dimensions larger than 1 on one of the inputs, or equal dimensions between the 2."
16381651
16391652 # Create flags for same dimension between each input matrix and the final batch dimension
1640- self .operatorRepresentation ['A_batched' ] = (self . operatorRepresentation [ 'batch' ] == np . prod (
1641- ctxt . lookup ( node . inputs [ 0 ]. name ). shape [: - 2 ]) )
1653+ self .operatorRepresentation ['A_batched' ] = (
1654+ self . operatorRepresentation [ 'batch' ] == self . operatorRepresentation [ 'batch_A' ] )
16421655 self .operatorRepresentation ['W_batched' ] = self .operatorRepresentation ['B_batched' ] = (
1643- self .operatorRepresentation ['batch' ] == np . prod ( ctxt . lookup ( node . inputs [ 1 ]. name ). shape [: - 2 ]) )
1656+ self .operatorRepresentation ['batch' ] == self . operatorRepresentation [ 'batch_B' ] )
16441657
16451658 return ctxt , ret
16461659
0 commit comments