@@ -151,10 +151,11 @@ class CreateNdDescToXeVMPattern
151151 auto loc = op.getLoc ();
152152 auto source = op.getSource ();
153153 // Op is lowered to a code sequence that populates payload.
154- // Payload is a 8xi32 vector.
154+ // Payload is a 8xi32 vector. Offset to individual fields are defined in
155+ // NdTdescOffset enum.
155156 Type payloadElemTy = rewriter.getI32Type ();
156- Type i64Ty = rewriter.getI64Type ();
157157 VectorType payloadTy = VectorType::get (8 , payloadElemTy);
158+ Type i64Ty = rewriter.getI64Type ();
158159 // 4xi64 view is used for inserting the base pointer.
159160 VectorType payloadI64Ty = VectorType::get (4 , i64Ty);
160161 // Initialize payload to zero.
@@ -180,12 +181,12 @@ class CreateNdDescToXeVMPattern
180181 // If source is a memref, we need to extract the aligned pointer as index.
181182 // Pointer type is passed as i32 or i64 by type converter.
182183 if (sourceMemrefTy) {
183- baseAddr =
184- memref::ExtractAlignedPointerAsIndexOp::create (rewriter, loc, source);
185184 if (!sourceMemrefTy.hasStaticShape ()) {
186185 op.emitError () << " Expected static memref shape." ;
187186 return failure ();
188187 }
188+ baseAddr =
189+ memref::ExtractAlignedPointerAsIndexOp::create (rewriter, loc, source);
189190 } else {
190191 baseAddr = adaptor.getSource ();
191192 }
@@ -198,8 +199,8 @@ class CreateNdDescToXeVMPattern
198199 };
199200 // Offsets can be either 2D or not provided (0 is used).
200201 if (mixedOffsets.size () == 2 ) {
201- offsetW = createOffset (mixedOffsets, rank - 1 );
202- offsetH = createOffset (mixedOffsets, rank - 2 );
202+ offsetW = createOffset (mixedOffsets, 1 );
203+ offsetH = createOffset (mixedOffsets, 0 );
203204 } else if (mixedOffsets.size () == 0 ) {
204205 offsetW = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
205206 offsetH = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
@@ -208,8 +209,8 @@ class CreateNdDescToXeVMPattern
208209 " Expected 2D offsets or no offsets." );
209210 }
210211 // Get shape values from op fold results.
211- baseShapeW = createOffset (mixedSizes, rank - 1 );
212- baseShapeH = createOffset (mixedSizes, rank - 2 );
212+ baseShapeW = createOffset (mixedSizes, 1 );
213+ baseShapeH = createOffset (mixedSizes, 0 );
213214 if (sourceMemrefTy)
214215 // Cast index to i64.
215216 baseAddr = arith::IndexCastUIOp::create (rewriter, loc, i64Ty, baseAddr);
0 commit comments