Skip to content

Commit d60d038

Browse files
authored
[MLIR][Conversion] XeGPU to XeVM: Remove unused type converter source materializations. (#162947)
And add source materialization for single element vector.
1 parent e712871 commit d60d038

File tree

2 files changed

+50
-11
lines changed

2 files changed

+50
-11
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2121
#include "mlir/Dialect/SCF/IR/SCF.h"
2222
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
23+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2324
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
2425
#include "mlir/Pass/Pass.h"
2526
#include "mlir/Support/LLVM.h"
@@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
390391
// Load result or Store value type can be vector or scalar.
391392
Type valOrResTy;
392393
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
393-
valOrResTy = op.getResult().getType();
394+
valOrResTy =
395+
this->getTypeConverter()->convertType(op.getResult().getType());
394396
else
395397
valOrResTy = adaptor.getValue().getType();
396398
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass
878880
}
879881
return {};
880882
};
881-
typeConverter.addSourceMaterialization(memrefMaterializationCast);
882-
typeConverter.addSourceMaterialization(ui64MaterializationCast);
883-
typeConverter.addSourceMaterialization(ui32MaterializationCast);
884-
typeConverter.addSourceMaterialization(vectorMaterializationCast);
883+
884+
// If the result type of the original op is a single-element vector and the
885+
// lowered type is scalar, this materialization cast creates a single-element
886+
// vector by broadcasting the scalar value.
887+
auto singleElementVectorMaterializationCast =
888+
[](OpBuilder &builder, Type type, ValueRange inputs,
889+
Location loc) -> Value {
890+
if (inputs.size() != 1)
891+
return {};
892+
auto input = inputs.front();
893+
if (input.getType().isIntOrIndexOrFloat()) {
894+
// If the input is a scalar, and the target type is a vector of single
895+
// element, create a single element vector by broadcasting.
896+
if (auto vecTy = dyn_cast<VectorType>(type)) {
897+
if (vecTy.getNumElements() == 1) {
898+
return vector::BroadcastOp::create(builder, loc, vecTy, input)
899+
.getResult();
900+
}
901+
}
902+
}
903+
return {};
904+
};
905+
typeConverter.addSourceMaterialization(
906+
singleElementVectorMaterializationCast);
885907
typeConverter.addTargetMaterialization(memrefMaterializationCast);
886908
typeConverter.addTargetMaterialization(ui32MaterializationCast);
887909
typeConverter.addTargetMaterialization(ui64MaterializationCast);

mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
1414
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
1515
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
1616
// CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
17-
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
18-
// CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
19-
// CHECK: scf.yield %[[VAR8]] : f16
20-
// CHECK: } else {
21-
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
22-
// CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
17+
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
2318
// CHECK: scf.yield %[[VAR7]] : f16
19+
// CHECK: } else {
20+
// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
21+
// CHECK: scf.yield %[[CST_0]] : f16
2422
// CHECK: }
2523
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
2624
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
2725
gpu.return
2826
}
2927
}
28+
29+
// -----
30+
gpu.module @test {
31+
// CHECK-LABEL: @source_materialize_single_elem_vec
32+
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
33+
gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
34+
%1 = arith.constant dense<1>: vector<1xi1>
35+
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
36+
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
37+
// CHECK: %[[VAR_IF:.*]] = scf.if
38+
// CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
39+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
40+
// CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
41+
%c0 = arith.constant 0 : index
42+
vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
43+
gpu.return
44+
}
45+
}
46+
3047
// -----
3148

3249
gpu.module @test {

0 commit comments

Comments
 (0)