@@ -175,43 +175,73 @@ class FuseReshape {
175175
176176 // Create a MakeTensorPtrOp yielding a 2-dim block pointer.
177177 auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType ());
178- ArrayRef<int64_t > origShape =
178+ [[maybe_unused]] ArrayRef<int64_t > resShape =
179179 cast<RankedTensorType>(ptrType.getPointeeType ()).getShape ();
180- assert (origShape[0 ] && " First shape extent is not one" );
180+ assert (resShape[0 ] == 1 && " Result shape should have extent equal to 1 in "
181+ " the outermost dimension" );
181182
182183 auto tensorType = cast<RankedTensorType>(reshapeOp.getType ());
183184 auto newPtrType =
184185 tt::PointerType::get (tensorType, ptrType.getAddressSpace ());
185186
186- unsigned innermostDimIdx = 0 ;
187+ // Compute the index of the innermost dimension.
187188 ArrayRef<int > order = makeTensorPtrOp.getOrder ();
188- for (int i : order) {
189- if (i == 0 )
189+ assert (order.size () == 3 && order[0 ] == 2 && " Invalid order" );
190+
191+ unsigned innermostDimIdx = 0 ;
192+ for (int elem : makeTensorPtrOp.getOrder ()) {
193+ if (elem == 0 )
190194 break ;
191195 ++innermostDimIdx;
192196 }
193197
194198 OpBuilder builder (makeTensorPtrOp);
195199 Location loc = makeTensorPtrOp.getLoc ();
196- Value firstShape = makeTensorPtrOp.getShape ().front ();
197- Value firstStride = makeTensorPtrOp.getStrides ().front ();
198- Value firstOffset = makeTensorPtrOp.getOffsets ().front ();
200+ OperandRange shapes = makeTensorPtrOp.getShape ();
201+ OperandRange strides = makeTensorPtrOp.getStrides ();
202+ OperandRange offsets = makeTensorPtrOp.getOffsets ();
203+
204+ #if 0
205+ // order=2,1,0 --> idx = 2 (row major) --> idx we want = 1
206+ // order=2,0,1 --> idx = 1 (column major) --> idx we want == 0
199207
200208 SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
201209 newShape[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
202- loc, builder.create <arith::MulIOp>(loc, firstStride, firstShape ),
210+ loc, builder.create<arith::MulIOp>(loc, strides[0], shapes[0] ),
203211 newShape[innermostDimIdx - 1]);
204212 SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
205213 SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());
206214 newOffsets[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
207215 loc,
208216 builder.create<arith::MulIOp>(
209217 loc,
210- builder.create <arith::TruncIOp>(loc, firstOffset .getType (),
211- firstStride ),
212- firstOffset ),
218+ builder.create<arith::TruncIOp>(loc, offsets[0] .getType(),
219+ strides[0] ),
220+ offsets[0] ),
213221 newOffsets[innermostDimIdx - 1]);
222+ #else
223+ // order=2,1,0 --> idx = 2 (row major) --> idx we want = 0
224+ // order=2,0,1 --> idx = 1 (column major) --> idx we want == 1
225+
226+ unsigned newInnermostDimIdx = (innermostDimIdx - 1 );
227+ unsigned newOutermostDimIdx = !newInnermostDimIdx;
214228
229+ SmallVector<Value> newShape (makeTensorPtrOp.getShape ().drop_front ());
230+ SmallVector<Value> newStrides (makeTensorPtrOp.getStrides ().drop_front ());
231+ SmallVector<Value> newOffsets (makeTensorPtrOp.getOffsets ().drop_front ());
232+
233+ auto div = builder.create <arith::DivUIOp>(loc, strides[0 ],
234+ newStrides[newOutermostDimIdx]);
235+ newShape[newOutermostDimIdx] = builder.create <arith::AddIOp>(
236+ loc, builder.create <arith::MulIOp>(loc, shapes[0 ], div),
237+ newShape[newOutermostDimIdx]);
238+ newOffsets[newOutermostDimIdx] = builder.create <arith::AddIOp>(
239+ loc,
240+ builder.create <arith::MulIOp>(
241+ loc, offsets[0 ],
242+ builder.create <arith::TruncIOp>(loc, offsets[0 ].getType (), div)),
243+ newOffsets[newOutermostDimIdx]);
244+ #endif
215245 Value ptr = builder.create <tt::MakeTensorPtrOp>(
216246 loc, newPtrType, makeTensorPtrOp.getBase (), newShape, newStrides,
217247 newOffsets,
0 commit comments