@@ -179,48 +179,49 @@ class CreateNdDescToXeVMPattern
179
179
Value baseShapeH;
180
180
Value offsetW;
181
181
Value offsetH;
182
+ auto convertToValue = [&](OpFoldResult ofr) -> Value {
183
+ Value val;
184
+ if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
185
+ val = rewriter.create <arith::IndexCastOp>(loc, i64Ty, v);
186
+ val = rewriter.create <arith::TruncIOp>(loc, payloadElemTy, val);
187
+ } else {
188
+ int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt ();
189
+ val = rewriter.create <arith::ConstantIntOp>(loc, payloadElemTy, off);
190
+ }
191
+ return val;
192
+ };
193
+
194
+ int rank = op.getMixedOffsets ().size ();
195
+ if (rank != 2 ) {
196
+ op.emitError () << " Expected 2D offsets, got " << rank << " D offsets." ;
197
+ return mlir::failure ();
198
+ }
199
+ offsetW = convertToValue (op.getMixedOffsets ()[rank - 1 ]);
200
+ offsetH = convertToValue (op.getMixedOffsets ()[rank - 2 ]);
182
201
183
202
if (auto sourceTy = source.getType (); isa<MemRefType>(sourceTy)) {
184
203
baseAddr =
185
204
rewriter.create <memref::ExtractAlignedPointerAsIndexOp>(loc, source);
205
+ baseAddr = rewriter.create <arith::IndexCastUIOp>(loc, i64Ty, baseAddr);
186
206
auto sourceMemrefTy = cast<MemRefType>(sourceTy);
187
207
if (!sourceMemrefTy.hasStaticShape ()) {
188
208
op.emitError () << " Expected static memref shape." ;
189
209
return mlir::failure ();
190
210
}
191
211
auto rank = sourceMemrefTy.getRank ();
192
- if (rank != 2 ) {
193
- op.emitError () << " Expected a 2D memref." ;
194
- return mlir::failure ();
195
- }
196
- auto createOffset = [&](unsigned idx) -> Value {
197
- Value val;
198
- OpFoldResult ofr = op.getMixedOffsets ()[idx];
199
- if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
200
- val = rewriter.create <arith::IndexCastOp>(loc, i64Ty, v);
201
- val = rewriter.create <arith::TruncIOp>(loc, payloadElemTy, val);
202
- } else {
203
- int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt ();
204
- val = rewriter.create <arith::ConstantIntOp>(loc, payloadElemTy, off);
205
- }
206
- return val;
207
- };
208
- offsetW = createOffset (rank - 1 );
209
- offsetH = createOffset (rank - 2 );
210
212
baseShapeW = rewriter.create <arith::ConstantIntOp>(
211
213
loc, payloadElemTy, sourceMemrefTy.getDimSize (rank - 1 ));
212
214
baseShapeH = rewriter.create <arith::ConstantIntOp>(
213
215
loc, payloadElemTy, sourceMemrefTy.getDimSize (rank - 2 ));
214
216
} else if (isa<IntegerType>(sourceTy)) {
215
- op. emitError ()
216
- << " Integer as source are currently not supported by the pass. " ;
217
- return mlir::failure ( );
217
+ baseAddr = source;
218
+ baseShapeW = convertToValue (op. getMixedSizes ()[rank - 1 ]) ;
219
+ baseShapeH = convertToValue (op. getMixedSizes ()[rank - 2 ] );
218
220
} else {
219
221
op.emitError () << " Unknown source type." ;
220
222
return mlir::failure ();
221
223
}
222
224
223
- baseAddr = rewriter.create <arith::IndexCastUIOp>(loc, i64Ty, baseAddr);
224
225
Value payLoadAsI64 =
225
226
rewriter.create <vector::BitCastOp>(loc, payloadI64Ty, payload);
226
227
payLoadAsI64 = rewriter.create <vector::InsertOp>(
0 commit comments