Skip to content

Commit 422c2a5

Browse files
authored
[GPUToSPIRV] Adding support for vector.constant_mask (#1071)
Leverage upstream pattern for vector.constant_mask lowering to SPIR-V
1 parent 156dfc4 commit 422c2a5

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <mlir/Dialect/SPIRV/IR/SPIRVOps.h>
3737
#include <mlir/Dialect/SPIRV/IR/SPIRVTypes.h>
3838
#include <mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h>
39+
#include <mlir/Dialect/Vector/Transforms/LoweringPatterns.h>
3940
#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
4041
#include <mlir/IR/BuiltinOps.h>
4142
#include <mlir/IR/Matchers.h>
@@ -371,6 +372,10 @@ void GPUXToSPIRVPass::runOnOperation() {
371372
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
372373
mlir::populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
373374
mlir::populateVectorToSPIRVPatterns(typeConverter, patterns);
375+
// this is for lowering of ConstantMaskOpLowering
376+
// but we should also consider replacing VectorMaskConversionPattern
377+
// with the upstream one
378+
mlir::vector::populateVectorMaskOpLoweringPatterns(patterns);
374379
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
375380
mlir::index::populateIndexToSPIRVPatterns(typeConverter, patterns);
376381
mlir::populateMemRefToSPIRVPatterns(typeConverter, patterns);
@@ -383,7 +388,6 @@ void GPUXToSPIRVPass::runOnOperation() {
383388
mlir::populateSCFToSPIRVPatterns(typeConverter, scfToSpirvCtx, patterns);
384389
mlir::cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
385390
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
386-
// for ub.poison op with vector operand
387391
imex::populateVectorToSPIRVPatterns(typeConverter, patterns);
388392

389393
if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))

test/Conversion/GPUToSPIRV/create_mask.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,36 @@ module attributes {
3030
// CHECK-NEXT: %[[CAST:.*]] = spirv.SConvert %[[SELECT2]] : i32 to i16
3131
// CHECK-NEXT: spirv.Bitcast %[[CAST]] : i16 to vector<16xi1>
3232
// CHECK-NEXT: spirv.Return
33+
34+
// -----
35+
36+
module attributes {
37+
gpu.container_module,
38+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
39+
[Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR],
40+
[SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, #spirv.resource_limits<>>
41+
} {
42+
gpu.module @kernels {
43+
// CHECK-LABEL: spirv.func @constant_mask_0
44+
gpu.func @constant_mask_0() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
45+
%0 = vector.constant_mask [0] : vector<16xi1>
46+
// CHECK-NEXT: spirv.Constant dense<false> : vector<16xi1>
47+
// CHECK-NEXT: spirv.Return
48+
gpu.return
49+
}
50+
// CHECK-LABEL: spirv.func @constant_mask_7
51+
gpu.func @constant_mask_7() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
52+
%7 = vector.constant_mask [7] : vector<16xi1>
53+
// CHECK-NEXT: spirv.Constant dense<[true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false]> : vector<16xi1>
54+
// CHECK-NEXT: spirv.Return
55+
gpu.return
56+
}
57+
// CHECK-LABEL: spirv.func @constant_mask_16
58+
gpu.func @constant_mask_16() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
59+
%16 = vector.constant_mask [16] : vector<16xi1>
60+
// CHECK-NEXT: spirv.Constant dense<true> : vector<16xi1>
61+
// CHECK-NEXT: spirv.Return
62+
gpu.return
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)