@@ -79,13 +79,11 @@ def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkCo
7979 bufferA = ctxt .lookup (name = parseDict ['A' ])
8080 bufferB = ctxt .lookup (name = parseDict ['B' ])
8181
82- tensorsShapeLen = len (bufferA .shape )
83-
8482 # ===== EXTRACT TENSOR DIMS AS VARS =====
8583 ASecondDimVar = tilerModel .getTensorDimVar (tensorName = bufferA .name ,
86- dimIdx = (tensorsShapeLen - 1 ) - parseDict ['transA' ])
84+ dimIdx = (len ( bufferA . shape ) - 1 ) - parseDict ['transA' ])
8785 BFirstDimVar = tilerModel .getTensorDimVar (tensorName = bufferB .name ,
88- dimIdx = (tensorsShapeLen - 2 ) + parseDict ['transB' ])
86+ dimIdx = (len ( bufferB . shape ) - 2 ) + parseDict ['transB' ])
8987
9088 # ===== ADD CONSTRAINTS =====
9189 # VIC: We don't want to deal with intermediate results between kernel calls
@@ -111,11 +109,15 @@ def serializeTilingSolution(
111109 buffB = ctxt .lookup (operatorRepresentation ['B' ])
112110 buffOut = ctxt .lookup (operatorRepresentation ['data_out' ])
113111
112+ transA = operatorRepresentation ['transA' ]
113+ transB = operatorRepresentation ['transB' ]
114+
114115 tensorsShapeLenA = len (buffA .shape )
115116 tensorsShapeLenB = len (buffB .shape )
116117 tensorsShapeOutput = len (buffOut .shape )
117118
118- NSize = buffA .shape [- 1 ]
119+ # NSize depends on transA: if transA=0, N is last dim; if transA=1, N is second-to-last
120+ NSize = buffA .shape [- 1 ] if transA == 0 else buffA .shape [- 2 ]
119121 NOffset = 0
120122
121123 # Prepare input cubes lists
@@ -148,9 +150,13 @@ def serializeTilingSolution(
148150 replacements ["batch" ].append (BatchSize )
149151
150152 # ===== Compute A cube information =====
151- # Matrix offsets and shape
152- AMatrixOffsets = (MOffset , NOffset )
153- AMatrixShape = (MSize , NSize )
153+ # Matrix offsets and shape (swap based on transA)
154+ if transA == 0 :
155+ AMatrixOffsets = (MOffset , NOffset )
156+ AMatrixShape = (MSize , NSize )
157+ else :
158+ AMatrixOffsets = (NOffset , MOffset )
159+ AMatrixShape = (NSize , MSize )
154160
155161 # Batch offset and shape (with broadcasting handling)
156162 ABatchOffsets = list ()
@@ -170,9 +176,13 @@ def serializeTilingSolution(
170176 inputACubes .append (ACube )
171177
172178 # ===== Compute B cube information =====
173- # Matrix offsets and shape
174- BMatrixOffsets = (NOffset , OOffset )
175- BMatrixShape = (NSize , OSize )
179+ # Matrix offsets and shape (swap based on transB)
180+ if transB == 0 :
181+ BMatrixOffsets = (NOffset , OOffset )
182+ BMatrixShape = (NSize , OSize )
183+ else :
184+ BMatrixOffsets = (OOffset , NOffset )
185+ BMatrixShape = (OSize , NSize )
176186
177187 # Batch offset and shape (with broadcasting handling)
178188 BBatchOffsets = list ()
@@ -206,7 +216,8 @@ def serializeTilingSolution(
206216 }
207217
208218 # Update load schedule lists
209- for a , b in zip (inputACubes , inputBCubes ):
219+ # *With strict=True to fail fast if different list lenghts
220+ for a , b in zip (inputACubes , inputBCubes , strict = True ):
210221 inputLoadSchedule .append ({"A" : a , "B" : b })
211222
212223 for out in outputCubes :
0 commit comments