diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 4891dab3aa1d0..c6c695b442b4f 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + Value initShflValue = adaptor.getValue(); + Type shflType = initShflValue.getType(); // TODO: Add support for non 32-bit shuffle values. - if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32) - return failure(); + if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32) + return rewriter.notifyMatchFailure( + op, "only 32-bit int/float types are supported"); + const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth); @@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { Value two = rewriter.create(loc, int32Type, 2); Value dwordAlignedDstLane = rewriter.create(loc, int32Type, selectDstLane, two); - Value initShflValue = adaptor.getValue(); - if (adaptor.getValue().getType().isF32()) { + if (shflType.isF32()) { initShflValue = rewriter.create(loc, int32Type, initShflValue); } Value shflValue = rewriter.create( loc, int32Type, dwordAlignedDstLane, initShflValue); - if (adaptor.getValue().getType().isF32()) { - shflValue = rewriter.create( - loc, adaptor.getValue().getType(), shflValue); + if (shflType.isF32()) { + shflValue = rewriter.create(loc, shflType, shflValue); } rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); return success(); diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp index 4bd4da25f6e52..9f2900214e8b1 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp @@ -40,8 +40,9 @@ struct GpuShuffleRewriter : public OpRewritePattern { auto i64 = rewriter.getI64Type(); // If the type of the value is either i32 or f32, the op is already valid. - if (valueType.getIntOrFloatBitWidth() == 32) - return failure(); + if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64) + return rewriter.notifyMatchFailure( + op, "only 64-bit int/float types are supported"); Value lo, hi; diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir new file mode 100644 index 0000000000000..90f2e5f047cd9 --- /dev/null +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics + +gpu.module @test_module { + // ROCDL lowering only suport shuffles for 32bit ints/floats, but they + // shouldn't crash on unsupported types. + func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> { + %offset = arith.constant 4 : i32 + %width = arith.constant 64 : i32 + // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}} + %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16> + return %shfl : vector<4xf16> + } +} diff --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir index 4618258201532..c0ccae05a0572 100644 --- a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir +++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir @@ -49,3 +49,14 @@ module { return } } + +// ----- + +// CHECK-LABEL: @gpu_shuffle_unsupported +func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> { + %offset = arith.constant 4 : i32 + %width = arith.constant 64 : i32 + // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16> + %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16> + return %shfl : vector<4xf16> +}