Skip to content

Commit 4fe1bd5

Browse files
authored
[mlir][spirv][vector] Use adaptor.getElements() in FromElements lowering. (#156972)
Signed-off-by: hanhanW <[email protected]>
1 parent 556ff19 commit 4fe1bd5

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ struct VectorFromElementsOpConvert final
278278
Type resultType = getTypeConverter()->convertType(op.getType());
279279
if (!resultType)
280280
return failure();
281-
OperandRange elements = op.getElements();
281+
ValueRange elements = adaptor.getElements();
282282
if (isa<spirv::ScalarType>(resultType)) {
283283
// In the case with a single scalar operand / single-element result,
284284
// pass through the scalar.

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,33 +281,46 @@ func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
281281

282282
// -----
283283

284-
// CHECK-LABEL: @from_elements_0d
284+
// CHECK-LABEL: @from_elements_0d_f32
285285
// CHECK-SAME: %[[ARG0:.+]]: f32
286286
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
287287
// CHECK: return %[[RETVAL]]
288-
func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
288+
func.func @from_elements_0d_f32(%arg0 : f32) -> vector<f32> {
289289
%0 = vector.from_elements %arg0 : vector<f32>
290290
return %0: vector<f32>
291291
}
292292

293-
// CHECK-LABEL: @from_elements_1x
293+
// CHECK-LABEL: @from_elements_1xf32
294294
// CHECK-SAME: %[[ARG0:.+]]: f32
295295
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
296296
// CHECK: return %[[RETVAL]]
297-
func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
297+
func.func @from_elements_1xf32(%arg0 : f32) -> vector<1xf32> {
298298
%0 = vector.from_elements %arg0 : vector<1xf32>
299299
return %0: vector<1xf32>
300300
}
301301

302-
// CHECK-LABEL: @from_elements_3x
302+
// CHECK-LABEL: @from_elements_3xf32
303303
// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
304304
// CHECK: %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
305305
// CHECK: return %[[RETVAL]]
306-
func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
306+
func.func @from_elements_3xf32(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
307307
%0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
308308
return %0: vector<3xf32>
309309
}
310310

311+
func.func @from_elements_3xi8(%arg0 : i8, %arg1 : i8, %arg2 : i8) -> vector<3xi8> {
312+
%0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xi8>
313+
return %0: vector<3xi8>
314+
}
315+
// CHECK-LABEL: @from_elements_3xi8
316+
// CHECK-SAME: %[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i8
317+
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : i8 to i32
318+
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
319+
// CHECK-DAG: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : i8 to i32
320+
// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[CAST0]], %[[CAST1]], %[[CAST2]] : (i32, i32, i32) -> vector<3xi32>
321+
// CHECK: %[[RETVAL:.*]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<3xi32> to vector<3xi8>
322+
// CHECK: return %[[RETVAL]]
323+
311324
// -----
312325

313326
// CHECK-LABEL: @insert

0 commit comments

Comments
 (0)