@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
7979// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
8080
8181
82+ // -----
83+
// Scalable-vector variant of the masked f32 add reduction test: reduces
// %arg0 (vector<[16]xf32>) with <add> inside vector.mask under %mask and
// returns the f32 result. The expectations that follow look for
// llvm.intr.vp.reduce.fadd with a 0.0 start value and a vector length of
// 16 multiplied by vscale.
84+ func.func @masked_reduce_add_f32_scalable (%arg0: vector <[16 ]xf32 >, %mask : vector <[16 ]xi1 >) -> f32 {
85+ %0 = vector.mask %mask { vector.reduction <add >, %arg0 : vector <[16 ]xf32 > into f32 } : vector <[16 ]xi1 > -> f32
86+ return %0 : f32
87+ }
88+
89+ // CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
90+ // CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
91+ // CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
92+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
93+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
94+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
95+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
96+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
97+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
98+ // CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
99+
100+
82101// -----
83102
84103func.func @masked_reduce_mul_f32 (%arg0: vector <16 xf32 >, %mask : vector <16 xi1 >) -> f32 {
@@ -110,6 +129,24 @@ func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)
110129
111130// -----
112131
// Scalable-vector variant of the masked f32 minnumf reduction test: reduces
// %arg0 (vector<[16]xf32>) with <minnumf> inside vector.mask under %mask and
// returns the f32 result. The expectations that follow look for
// llvm.intr.vp.reduce.fmin with a 0xFFC00000 start value and a vector length
// of 16 multiplied by vscale.
132+ func.func @masked_reduce_minf_f32_scalable (%arg0: vector <[16 ]xf32 >, %mask : vector <[16 ]xi1 >) -> f32 {
133+ %0 = vector.mask %mask { vector.reduction <minnumf >, %arg0 : vector <[16 ]xf32 > into f32 } : vector <[16 ]xi1 > -> f32
134+ return %0 : f32
135+ }
136+
137+ // CHECK-LABEL: func.func @masked_reduce_minf_f32_scalable(
138+ // CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
139+ // CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
140+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
141+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
142+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
143+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
144+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
145+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
146+ // CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
147+
148+ // -----
149+
113150func.func @masked_reduce_maxf_f32 (%arg0: vector <16 xf32 >, %mask : vector <16 xi1 >) -> f32 {
114151 %0 = vector.mask %mask { vector.reduction <maxnumf >, %arg0 : vector <16 xf32 > into f32 } : vector <16 xi1 > -> f32
115152 return %0 : f32
@@ -167,6 +204,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
167204// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
168205
169206
207+ // -----
208+
// Scalable-vector variant of the masked i8 add reduction test: reduces
// %arg0 (vector<[32]xi8>) with <add> inside vector.mask under %mask and
// returns the i8 result. The expectations that follow look for
// llvm.intr.vp.reduce.add with a 0 start value and a vector length of
// 32 multiplied by vscale.
209+ func.func @masked_reduce_add_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 {
210+ %0 = vector.mask %mask { vector.reduction <add >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8
211+ return %0 : i8
212+ }
213+
214+ // CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
215+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
216+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
217+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
218+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
219+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
220+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
221+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
222+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
223+ // CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
224+
225+
170226// -----
171227
172228func.func @masked_reduce_mul_i8 (%arg0: vector <32 xi8 >, %mask : vector <32 xi1 >) -> i8 {
@@ -197,6 +253,24 @@ func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
197253
198254// -----
199255
// Scalable-vector variant of the masked i8 unsigned-min reduction test:
// reduces %arg0 (vector<[32]xi8>) with <minui> inside vector.mask under %mask
// and returns the i8 result. The expectations that follow look for
// llvm.intr.vp.reduce.umin with a -1 (all bits set) start value and a vector
// length of 32 multiplied by vscale.
256+ func.func @masked_reduce_minui_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 {
257+ %0 = vector.mask %mask { vector.reduction <minui >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8
258+ return %0 : i8
259+ }
260+
261+ // CHECK-LABEL: func.func @masked_reduce_minui_i8_scalable(
262+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
263+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
264+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
265+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
266+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
267+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
268+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
269+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
270+ // CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
271+
272+ // -----
273+
200274func.func @masked_reduce_maxui_i8 (%arg0: vector <32 xi8 >, %mask : vector <32 xi1 >) -> i8 {
201275 %0 = vector.mask %mask { vector.reduction <maxui >, %arg0 : vector <32 xi8 > into i8 } : vector <32 xi1 > -> i8
202276 return %0 : i8
@@ -239,6 +313,24 @@ func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -
239313
240314// -----
241315
// Scalable-vector variant of the masked i8 signed-max reduction test:
// reduces %arg0 (vector<[32]xi8>) with <maxsi> inside vector.mask under %mask
// and returns the i8 result. The expectations that follow look for
// llvm.intr.vp.reduce.smax with a -128 start value and a vector length of
// 32 multiplied by vscale.
316+ func.func @masked_reduce_maxsi_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 {
317+ %0 = vector.mask %mask { vector.reduction <maxsi >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8
318+ return %0 : i8
319+ }
320+
321+ // CHECK-LABEL: func.func @masked_reduce_maxsi_i8_scalable(
322+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
323+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
324+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
325+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
326+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
327+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
328+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
329+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
330+ // CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
331+
332+ // -----
333+
242334func.func @masked_reduce_or_i8 (%arg0: vector <32 xi8 >, %mask : vector <32 xi1 >) -> i8 {
243335 %0 = vector.mask %mask { vector.reduction <or >, %arg0 : vector <32 xi8 > into i8 } : vector <32 xi1 > -> i8
244336 return %0 : i8
@@ -280,4 +372,22 @@ func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
280372// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
281373// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
282374
375+ // -----
376+
// Scalable-vector variant of the masked i8 xor reduction test: reduces
// %arg0 (vector<[32]xi8>) with <xor> inside vector.mask under %mask and
// returns the i8 result. The expectations that follow look for
// llvm.intr.vp.reduce.xor with a 0 start value and a vector length of
// 32 multiplied by vscale.
377+ func.func @masked_reduce_xor_i8_scalable (%arg0: vector <[32 ]xi8 >, %mask : vector <[32 ]xi1 >) -> i8 {
378+ %0 = vector.mask %mask { vector.reduction <xor >, %arg0 : vector <[32 ]xi8 > into i8 } : vector <[32 ]xi1 > -> i8
379+ return %0 : i8
380+ }
381+
382+ // CHECK-LABEL: func.func @masked_reduce_xor_i8_scalable(
383+ // CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
384+ // CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
385+ // CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
386+ // CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
387+ // CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
388+ // CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
389+ // CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
390+ // CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
391+ // CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
392+
283393
0 commit comments