Skip to content

Commit 02bd06c

Browse files
committed
Remove unused type converter source materializations.
Add source materialization for single element vector.
1 parent b92483c commit 02bd06c

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2121
#include "mlir/Dialect/SCF/IR/SCF.h"
2222
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
23+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2324
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
2425
#include "mlir/Pass/Pass.h"
2526
#include "mlir/Support/LLVM.h"
@@ -389,7 +390,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
389390
// Load result or Store valye Type can be vector or scalar.
390391
Type valOrResTy;
391392
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
392-
valOrResTy = op.getResult().getType();
393+
valOrResTy =
394+
this->getTypeConverter()->convertType(op.getResult().getType());
393395
else
394396
valOrResTy = adaptor.getValue().getType();
395397
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -879,10 +881,27 @@ struct ConvertXeGPUToXeVMPass
879881
}
880882
return {};
881883
};
882-
typeConverter.addSourceMaterialization(memrefMaterializationCast);
883-
typeConverter.addSourceMaterialization(ui64MaterializationCast);
884-
typeConverter.addSourceMaterialization(ui32MaterializationCast);
885-
typeConverter.addSourceMaterialization(vectorMaterializationCast);
884+
885+
auto singleElementVectorMaterializationCast =
886+
[](OpBuilder &builder, Type type, ValueRange inputs,
887+
Location loc) -> Value {
888+
if (inputs.size() != 1)
889+
return {};
890+
auto input = inputs.front();
891+
if (input.getType().isIntOrIndexOrFloat()) {
892+
// If the input is a scalar, and the target type is a vector of single
893+
// element, create a single element vector by broadcasting.
894+
if (auto vecTy = dyn_cast<VectorType>(type)) {
895+
if (vecTy.getNumElements() == 1) {
896+
return vector::BroadcastOp::create(builder, loc, vecTy, input)
897+
.getResult();
898+
}
899+
}
900+
}
901+
return {};
902+
};
903+
typeConverter.addSourceMaterialization(
904+
singleElementVectorMaterializationCast);
886905
typeConverter.addTargetMaterialization(memrefMaterializationCast);
887906
typeConverter.addTargetMaterialization(ui32MaterializationCast);
888907
typeConverter.addTargetMaterialization(ui64MaterializationCast);

mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,32 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
1414
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
1515
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
1616
// CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
17-
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
18-
// CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
19-
// CHECK: scf.yield %[[VAR8]] : f16
20-
// CHECK: } else {
21-
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
22-
// CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
17+
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
2318
// CHECK: scf.yield %[[VAR7]] : f16
19+
// CHECK: } else {
20+
// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
21+
// CHECK: scf.yield %[[CST_0]] : f16
2422
// CHECK: }
2523
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
2624
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
2725
gpu.return
2826
}
2927
}
28+
29+
// -----
30+
module @test {
31+
// CHECK-LABEL: @source_materialize_single_elem_vec
32+
func.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>) -> vector<1xf16> {
33+
%1 = arith.constant dense<1>: vector<1xi1>
34+
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
35+
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
36+
// CHECK: %[[VAR_IF:.*]] = scf.if
37+
// CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
38+
// CHECK: return %[[VAR_RET]] : vector<1xf16>
39+
return %3 : vector<1xf16>
40+
}
41+
}
42+
3043
// -----
3144

3245
gpu.module @test {

0 commit comments

Comments
 (0)