Skip to content

Commit dea2933

Browse files
committed
Add element bitwidth restriction.
1 parent 10a6aff commit dea2933

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
290290
auto tdescTy = op.getTensorDescType();
291291
if (tdescTy.getRank() != 2)
292292
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
293+
auto elemType = tdescTy.getElementType();
294+
auto elemBitSize = elemType.getIntOrFloatBitWidth();
295+
if (elemBitSize % 8 != 0)
296+
return rewriter.notifyMatchFailure(
297+
op, "Expected element type bit width to be multiple of 8.");
293298

294299
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
295300
Value payLoadAsI64 =
@@ -333,8 +338,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
333338
Value basePtrLLVM =
334339
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
335340
// Compute element byte size and surface width in bytes.
336-
auto elemType = tdescTy.getElementType();
337-
auto elemBitSize = elemType.getIntOrFloatBitWidth();
338341
Value elemByteSize = arith::ConstantIntOp::create(
339342
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
340343
Value surfaceW =

0 commit comments

Comments
 (0)