Skip to content

Commit 563a488

Browse files
committed
Address reviewer comments.
1 parent 02bd06c commit 563a488

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,9 @@ struct ConvertXeGPUToXeVMPass
882882
return {};
883883
};
884884

885+
// If result type of original op is single element vector and lowered type
886+
// is scalar. This materialization cast creates a single element vector by
887+
// broadcasting the scalar value.
885888
auto singleElementVectorMaterializationCast =
886889
[](OpBuilder &builder, Type type, ValueRange inputs,
887890
Location loc) -> Value {

mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,20 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
2727
}
2828

2929
// -----
30-
module @test {
30+
gpu.module @test {
3131
// CHECK-LABEL: @source_materialize_single_elem_vec
32-
func.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>) -> vector<1xf16> {
32+
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
33+
gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
3334
%1 = arith.constant dense<1>: vector<1xi1>
3435
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
3536
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
3637
// CHECK: %[[VAR_IF:.*]] = scf.if
3738
// CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
38-
// CHECK: return %[[VAR_RET]] : vector<1xf16>
39-
return %3 : vector<1xf16>
39+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
40+
// CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
41+
%c0 = arith.constant 0 : index
42+
vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
43+
gpu.return
4044
}
4145
}
4246

0 commit comments

Comments
 (0)