@@ -48,15 +48,6 @@ namespace {
4848static constexpr int32_t systolicDepth{8 };
4949static constexpr int32_t executionSize{16 };
5050
51- // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
52- enum class NdTdescOffset : uint32_t {
53- BasePtr = 0 , // Base pointer (i64)
54- BaseShapeW = 2 , // Base shape width (i32)
55- BaseShapeH = 3 , // Base shape height (i32)
56- TensorOffsetW = 4 , // Tensor offset W (i32)
57- TensorOffsetH = 5 // Tensor offset H (i32)
58- };
59-
6051static int32_t getNumericXeVMAddrSpace (xegpu::MemorySpace xeGpuMemspace) {
6152 switch (xeGpuMemspace) {
6253 case xegpu::MemorySpace::Global:
@@ -177,92 +168,14 @@ class CreateNdDescToXeVMPattern
177168 if (mixedOffsets.size () != 0 )
178169 return rewriter.notifyMatchFailure (op, " Offsets not supported." );
179170 auto loc = op.getLoc ();
180- auto source = op.getSource ();
181- // Op is lowered to a code sequence that populates payload.
182- // Payload is a 8xi32 vector. Offset to individual fields are defined in
183- // NdTdescOffset enum.
184- Type payloadElemTy = rewriter.getI32Type ();
185- VectorType payloadTy = VectorType::get (8 , payloadElemTy);
186- Type i64Ty = rewriter.getI64Type ();
187- // 4xi64 view is used for inserting the base pointer.
188- VectorType payloadI64Ty = VectorType::get (4 , i64Ty);
189- // Initialize payload to zero.
190- Value payload = arith::ConstantOp::create (
191- rewriter, loc,
192- DenseElementsAttr::get (payloadTy, IntegerAttr::get (payloadElemTy, 0 )));
193-
194- Value baseAddr;
195- Value baseShapeW;
196- Value baseShapeH;
197- Value offsetW;
198- Value offsetH;
199171
200- // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
201- SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes ();
202- auto srcRank = mixedSizes.size ();
203- if (srcRank < 2 )
204- return rewriter.notifyMatchFailure (op, " Expected at least 2D source." );
205-
206- auto sourceTy = source.getType ();
207- auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
208- // If source is a memref, we need to extract the aligned pointer as index.
209- // Pointer type is passed as i32 or i64 by type converter.
210- if (sourceMemrefTy) {
211- if (!sourceMemrefTy.hasStaticShape ()) {
212- return rewriter.notifyMatchFailure (op, " Expected static memref shape." );
213- }
214- baseAddr =
215- memref::ExtractAlignedPointerAsIndexOp::create (rewriter, loc, source);
216- } else {
217- baseAddr = adaptor.getSource ();
218- }
219- // Utility for creating offset values from op fold result.
220- auto createOffset = [&](OpFoldResult ofr) -> Value {
221- Value val = getValueOrCreateConstantIntOp (rewriter, loc, ofr);
222- val = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy, val);
223- return val;
224- };
225- // Offsets are not supported (0 is used).
226- offsetW = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
227- offsetH = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
228- // Get shape values from op fold results.
229- baseShapeW = createOffset (mixedSizes[srcRank - 1 ]);
230- if (srcRank == 2 ) {
231- baseShapeH = createOffset (mixedSizes[0 ]);
232- } else {
233- // Generate compute chain for height (product of sizes of all but the last
234- // dimension).
235- baseShapeH = getProductOfSizes (rewriter, loc, mixedSizes, 0 , srcRank - 1 );
236- baseShapeH = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy,
237- baseShapeH);
238- }
239- if (sourceMemrefTy) {
240- // Cast index to i64.
241- baseAddr = arith::IndexCastUIOp::create (rewriter, loc, i64Ty, baseAddr);
242- } else if (baseAddr.getType () != i64Ty) {
172+ Value baseAddr = adaptor.getSource ();
173+ Type i64Ty = rewriter.getI64Type ();
174+ if (baseAddr.getType () != i64Ty) {
243175 // Pointer type may be i32. Cast to i64 if needed.
244176 baseAddr = arith::ExtUIOp::create (rewriter, loc, i64Ty, baseAddr);
245177 }
246- // Populate payload.
247- Value payLoadAsI64 =
248- vector::BitCastOp::create (rewriter, loc, payloadI64Ty, payload);
249- payLoadAsI64 =
250- vector::InsertOp::create (rewriter, loc, baseAddr, payLoadAsI64,
251- static_cast <int >(NdTdescOffset::BasePtr));
252- payload = vector::BitCastOp::create (rewriter, loc, payloadTy, payLoadAsI64);
253- payload =
254- vector::InsertOp::create (rewriter, loc, baseShapeW, payload,
255- static_cast <int >(NdTdescOffset::BaseShapeW));
256- payload =
257- vector::InsertOp::create (rewriter, loc, baseShapeH, payload,
258- static_cast <int >(NdTdescOffset::BaseShapeH));
259- payload = vector::InsertOp::create (
260- rewriter, loc, offsetW, payload,
261- static_cast <int >(NdTdescOffset::TensorOffsetW));
262- payload = vector::InsertOp::create (
263- rewriter, loc, offsetH, payload,
264- static_cast <int >(NdTdescOffset::TensorOffsetH));
265- rewriter.replaceOp (op, payload);
178+ rewriter.replaceOp (op, baseAddr);
266179 return success ();
267180 }
268181};
@@ -291,7 +204,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
291204 auto loc = op.getLoc ();
292205 auto ctxt = rewriter.getContext ();
293206
294- auto tdesc = adaptor.getTensorDesc ();
295207 auto tdescTy = op.getTensorDescType ();
296208 if (tdescTy.getRank () != 2 )
297209 return rewriter.notifyMatchFailure (op, " Expected 2D tensor descriptor." );
@@ -301,15 +213,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
301213 return rewriter.notifyMatchFailure (
302214 op, " Expected element type bit width to be multiple of 8." );
303215
304- VectorType payloadI64Ty = VectorType::get (4 , rewriter.getI64Type ());
305- Value payLoadAsI64 =
306- vector::BitCastOp::create (rewriter, loc, payloadI64Ty, tdesc);
307- Value basePtr = vector::ExtractOp::create (
308- rewriter, loc, payLoadAsI64, static_cast <int >(NdTdescOffset::BasePtr));
309- Value baseShapeW = vector::ExtractOp::create (
310- rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeW));
311- Value baseShapeH = vector::ExtractOp::create (
312- rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
216+ Value basePtr = adaptor.getTensorDesc ();
217+ // Utility for creating offset values from op fold result.
218+ Type payloadElemTy = rewriter.getIntegerType (32 );
219+ auto createOffset = [&](OpFoldResult ofr) -> Value {
220+ Value val = getValueOrCreateConstantIntOp (rewriter, loc, ofr);
221+ val = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy, val);
222+ return val;
223+ };
224+ auto srcRank = mixedSizes.size ();
225+ // Get shape values from op fold results.
226+ Value baseShapeW = createOffset (mixedSizes[srcRank - 1 ]);
227+ Value baseShapeH;
228+ if (srcRank == 2 ) {
229+ baseShapeH = createOffset (mixedSizes[0 ]);
230+ } else {
231+ // Generate compute chain for height (product of sizes of all but the last
232+ // dimension).
233+ baseShapeH = getProductOfSizes (rewriter, loc, mixedSizes, 0 , srcRank - 1 );
234+ baseShapeH = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy,
235+ baseShapeH);
236+ }
313237 // Offsets are provided by the op.
314238 // convert them to i32.
315239 // Offset computation assumes base memory layout is row major.
@@ -979,10 +903,7 @@ struct ConvertXeGPUToXeVMPass
979903 return VectorType::get (sum, elemType);
980904 });
981905 typeConverter.addConversion ([&](xegpu::TensorDescType type) -> Type {
982- if (type.isScattered ())
983- return IntegerType::get (&getContext (), 64 );
984- auto i32Type = IntegerType::get (&getContext (), 32 );
985- return VectorType::get (8 , i32Type);
906+ return IntegerType::get (&getContext (), 64 );
986907 });
987908 // Convert MemDescType into flattened MemRefType for SLM
988909 typeConverter.addConversion ([&](xegpu::MemDescType type) -> Type {
0 commit comments