Skip to content

Commit 6c91310

Browse files
committed
Fix transA and transB being treated as ints
1 parent 3c18467 commit 6c91310

File tree

9 files changed

+50
-43
lines changed

9 files changed

+50
-43
lines changed

Deeploy/Targets/Generic/Parsers.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,27 +1605,40 @@ def parseNodeCtxt(self,
16051605
node.inputs.append(zeroTensor)
16061606
self.operatorRepresentation['C'] = f'{node.name}_C_Tensor'
16071607

1608+
buffA = ctxt.lookup(node.inputs[0].name)
1609+
assert isinstance(buffA, VariableBuffer)
1610+
buffB = ctxt.lookup(node.inputs[1].name)
1611+
assert isinstance(buffB, VariableBuffer)
1612+
buffOut = ctxt.lookup(node.outputs[0].name)
1613+
assert isinstance(buffOut, VariableBuffer)
1614+
16081615
# Store the input and output shapes in the operator representation
1609-
self.operatorRepresentation['size'] = np.prod(ctxt.lookup(node.inputs[0].name).shape)
1610-
self.operatorRepresentation['A_shape'] = ctxt.lookup(node.inputs[0].name).shape
1611-
self.operatorRepresentation['B_shape'] = ctxt.lookup(node.inputs[1].name).shape
1612-
self.operatorRepresentation['data_out_shape'] = ctxt.lookup(node.outputs[0].name).shape
1616+
self.operatorRepresentation['size'] = np.prod(buffA.shape)
1617+
self.operatorRepresentation['A_shape'] = buffA.shape
1618+
self.operatorRepresentation['B_shape'] = buffB.shape
1619+
self.operatorRepresentation['data_out_shape'] = buffOut.shape
1620+
1621+
if self.operatorRepresentation['transA']:
1622+
N_A, M = buffA.shape[-2:]
1623+
else:
1624+
M, N_A = buffA.shape[-2:]
1625+
1626+
if self.operatorRepresentation['transB']:
1627+
O, N_B = buffB.shape[-2:]
1628+
else:
1629+
N_B, O = buffB.shape[-2:]
16131630

16141631
# Store the matrix dimensions in the operator representation
1615-
self.operatorRepresentation['M'] = ctxt.lookup(
1616-
node.inputs[0].name).shape[(-2 + self.operatorRepresentation['transA'])]
1617-
self.operatorRepresentation['N'] = ctxt.lookup(
1618-
node.inputs[0].name).shape[(-1 - self.operatorRepresentation['transA'])]
1619-
self.operatorRepresentation['O'] = ctxt.lookup(
1620-
node.inputs[1].name).shape[(-1 - self.operatorRepresentation['transB'])]
1632+
self.operatorRepresentation['M'] = M
1633+
self.operatorRepresentation['N'] = N_A
1634+
self.operatorRepresentation['O'] = O
16211635

16221636
# SCHEREMO: Assert that reduction dimension is the same on both matrices
1623-
ret = ret and (self.operatorRepresentation['N'] == ctxt.lookup(
1624-
node.inputs[1].name).shape[-2 + self.operatorRepresentation['transB']])
1637+
ret = ret and N_A == N_B
16251638

16261639
# Check if the batch dimensions are compatible
1627-
self.operatorRepresentation['batch_A'] = np.prod(ctxt.lookup(node.inputs[0].name).shape[:-2])
1628-
self.operatorRepresentation['batch_B'] = np.prod(ctxt.lookup(node.inputs[1].name).shape[:-2])
1640+
self.operatorRepresentation['batch_A'] = np.prod(buffA.shape[:-2])
1641+
self.operatorRepresentation['batch_B'] = np.prod(buffB.shape[:-2])
16291642

16301643
self.operatorRepresentation['batch'] = max(self.operatorRepresentation['batch_A'],
16311644
self.operatorRepresentation['batch_B'])
@@ -1637,10 +1650,10 @@ def parseNodeCtxt(self,
16371650
), "Incompatible dimensions for input matrices. Broadcasting not yet supported for dimensions larger than 1 on one of the inputs, or equal dimensions between the 2."
16381651

16391652
# Create flags for same dimension between each input matrix and the final batch dimension
1640-
self.operatorRepresentation['A_batched'] = (self.operatorRepresentation['batch'] == np.prod(
1641-
ctxt.lookup(node.inputs[0].name).shape[:-2]))
1653+
self.operatorRepresentation['A_batched'] = (
1654+
self.operatorRepresentation['batch'] == self.operatorRepresentation['batch_A'])
16421655
self.operatorRepresentation['W_batched'] = self.operatorRepresentation['B_batched'] = (
1643-
self.operatorRepresentation['batch'] == np.prod(ctxt.lookup(node.inputs[1].name).shape[:-2]))
1656+
self.operatorRepresentation['batch'] == self.operatorRepresentation['batch_B'])
16441657

16451658
return ctxt, ret
16461659

Deeploy/Targets/Generic/Templates/FloatGemmTemplate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
${M},
2222
${N},
2323
${O},
24-
${transA},
25-
${transB}
24+
${int(transA)},
25+
${int(transB)}
2626
);
2727
2828
% if A_batched:

Deeploy/Targets/Generic/Templates/GemmTemplate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def alignToContext(self, ctxt: NetworkContext,
5656
${O},
5757
${alpha},
5858
${beta},
59-
${transA},
60-
${transB},
59+
${int(transA)},
60+
${int(transB)},
6161
${A_offset},
6262
${B_offset},
6363
${C_offset},

Deeploy/Targets/Generic/TypeCheckers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,8 @@ def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[
185185

186186
def _inferNumLevels(self, inputs: List[VariableBuffer],
187187
operatorRepresentation: OperatorRepresentation) -> List[int]:
188-
return [
189-
2**((self.input_types[0].referencedType.typeWidth) * 2) *
190-
inputs[0].shape[-1 - operatorRepresentation['transA']]
191-
]
188+
O = inputs[0].shape[-1] if not operatorRepresentation['transA'] else inputs[0].shape[-2]
189+
return [2**((self.input_types[0].referencedType.typeWidth) * 2) * O]
192190

193191
def _inferSignedness(self, inputs: List[VariableBuffer],
194192
operatorRepresentation: OperatorRepresentation) -> List[bool]:

Deeploy/Targets/MemPool/Templates/GemmTemplate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
127127
${O},
128128
${alpha},
129129
${beta},
130-
${transA},
131-
${transB},
130+
${int(transA)},
131+
${int(transB)},
132132
${A_offset},
133133
${B_offset},
134134
${C_offset},

Deeploy/Targets/MemPool/Templates/RQGemmTemplate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
145145
${O},
146146
${alpha},
147147
${beta},
148-
${transA},
149-
${transB},
148+
${int(transA)},
149+
${int(transB)},
150150
${mul},
151151
${add},
152152
${log2Dstring},
@@ -170,8 +170,8 @@ def hoistTransientBuffers(self, ctxt: NetworkContext,
170170
${O},
171171
${alpha},
172172
${beta},
173-
${transA},
174-
${transB},
173+
${int(transA)},
174+
${int(transB)},
175175
${mul},
176176
${add},
177177
${log2Dstring},

Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
${M},
2121
${N},
2222
${O},
23-
${transA},
24-
${transB}
23+
${int(transA)},
24+
${int(transB)}
2525
);
2626
2727
ref_${data_out}_${A} += ${M} * ${N};

Deeploy/Targets/PULPOpen/TileConstraints/MatMulTileConstraint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
3232
tensorsShapeLen = len(bufferA.shape)
3333

3434
AFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name,
35-
dimIdx = (tensorsShapeLen - 2) + parseDict['transA'])
35+
dimIdx = (tensorsShapeLen - 2) + int(parseDict['transA']))
3636
ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name,
37-
dimIdx = (tensorsShapeLen - 1) - parseDict['transA'])
37+
dimIdx = (tensorsShapeLen - 1) - int(parseDict['transA']))
3838
BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name,
39-
dimIdx = (tensorsShapeLen - 2) + parseDict['transB'])
39+
dimIdx = (tensorsShapeLen - 2) + int(parseDict['transB']))
4040
BSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name,
41-
dimIdx = (tensorsShapeLen - 1) - parseDict['transB'])
41+
dimIdx = (tensorsShapeLen - 1) - int(parseDict['transB']))
4242
outputFirstDimVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = (tensorsShapeLen - 2))
4343
outputSecondDimVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = (tensorsShapeLen - 1))
4444

Deeploy/Targets/Snitch/Parsers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ def parseNode(self, node: gs.Node) -> bool:
1818
if not ret:
1919
return False
2020

21-
if not all([
22-
self.operatorRepresentation['transA'] == 0,
23-
]):
21+
if self.operatorRepresentation['transA']:
2422
return False
2523

2624
return True
@@ -50,9 +48,7 @@ def parseNode(self, node: gs.Node) -> bool:
5048
if not ret:
5149
return False
5250

53-
if not all([
54-
self.operatorRepresentation['transA'] == 0,
55-
]):
51+
if self.operatorRepresentation['transA']:
5652
return False
5753

5854
return True

0 commit comments

Comments (0)