Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
.Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
.Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; })
.Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
.Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
.Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
// TODO: AtomicRMW supports other kinds of reductions that this is
// currently not detecting; add those when the need arises.
return std::nullopt;
});
if (!maybeKind)
Expand Down
10 changes: 9 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
case arith::AtomicRMWKind::ori:
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::OR, vector);
// TODO: Add remaining reduction operations.
case arith::AtomicRMWKind::minnumf:
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MINNUMF, vector);
case arith::AtomicRMWKind::maxnumf:
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::MAXNUMF, vector);
case arith::AtomicRMWKind::xori:
return vector::ReductionOp::create(builder, vector.getLoc(),
CombiningKind::XOR, vector);
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
break;
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Conversion/ConvertToSPIRV/vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,42 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {

// -----

// CHECK-LABEL: spirv.func @reduction_minnumf(
// CHECK-SAME: %[[V:.*]]: vector<3xf32>,
// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" {
// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MIN0:.*]] = spirv.GL.FMin %[[S0]], %[[S1]] : f32
// CHECK: %[[MIN1:.*]] = spirv.GL.FMin %[[MIN0]], %[[S2]] : f32
// CHECK: %[[MIN2:.*]] = spirv.GL.FMin %[[MIN1]], %[[S]] : f32
// CHECK: spirv.ReturnValue %[[MIN2]] : f32
// CHECK: }
// Test input: reduces the 3-element f32 vector %v with the `minnumf`
// combining kind, passing the scalar %s as the accumulator operand, and
// returns the resulting f32.
func.func @reduction_minnumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <minnumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
}

// -----

// CHECK-LABEL: spirv.func @reduction_maxnumf(
// CHECK-SAME: %[[V:.*]]: vector<3xf32>,
// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" {
// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
// CHECK: %[[MAX0:.*]] = spirv.GL.FMax %[[S0]], %[[S1]] : f32
// CHECK: %[[MAX1:.*]] = spirv.GL.FMax %[[MAX0]], %[[S2]] : f32
// CHECK: %[[MAX2:.*]] = spirv.GL.FMax %[[MAX1]], %[[S]] : f32
// CHECK: spirv.ReturnValue %[[MAX2]] : f32
// CHECK: }
// Test input: reduces the 3-element f32 vector %v with the `maxnumf`
// combining kind, passing the scalar %s as the accumulator operand, and
// returns the resulting f32.
func.func @reduction_maxnumf(%v : vector<3xf32>, %s: f32) -> f32 {
%reduce = vector.reduction <maxnumf>, %v, %s : vector<3xf32> into f32
return %reduce : f32
}

// -----

// CHECK-LABEL: func @reduction_maxsi
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>
Expand Down
100 changes: 100 additions & 0 deletions mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,106 @@ func.func @vecdim_reduction_ori(%in: memref<256x512xi32>, %out: memref<256xi32>)
// CHECK: affine.store %[[final_red]], %{{.*}} : memref<256xi32>
// CHECK: }

// -----

// Test input: for each of the 256 rows of %in, XOR-reduces the 512 i32
// elements of that row and stores the result to %out at the row index.
func.func @vecdim_reduction_xori(%in: memref<256x512xi32>, %out: memref<256xi32>) {
// Initial value of the reduction carried through the inner loop.
%cst = arith.constant 0 : i32
affine.for %i = 0 to 256 {
// The inner loop carries the running XOR in %red_iter via iter_args.
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
%ld = affine.load %in[%i, %j] : memref<256x512xi32>
%xor = arith.xori %red_iter, %ld : i32
affine.yield %xor : i32
}
affine.store %final_red, %out[%i] : memref<256xi32>
}
return
}

// CHECK-LABEL: func.func @vecdim_reduction_xori(
// CHECK-SAME: %[[input:.*]]: memref<256x512xi32>,
// CHECK-SAME: %[[output:.*]]: memref<256xi32>) {
// CHECK: %[[cst:.*]] = arith.constant 0 : i32
// CHECK: affine.for %{{.*}} = 0 to 256 {
// CHECK: %[[vzero:.*]] = arith.constant dense<0> : vector<128xi32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xi32>) {
// CHECK: %[[poison:.*]] = ub.poison : i32
// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xi32>, vector<128xi32>
// CHECK: %[[xor:.*]] = arith.xori %[[red_iter]], %[[ld]] : vector<128xi32>
// CHECK: affine.yield %[[xor]] : vector<128xi32>
// CHECK: }
// CHECK: %[[final_red:.*]] = vector.reduction <xor>, %[[vred]] : vector<128xi32> into i32
// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xi32>
// CHECK: }
// CHECK: return
// CHECK: }

// -----

// Test input: for each of the 256 rows of %in, reduces the 512 f32 elements
// of that row with arith.minnumf and stores the result to %out at the row
// index.
func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
// Initial value of the reduction carried through the inner loop.
// NOTE(review): 0xFF800000 is the f32 bit pattern for -infinity, so a min
// reduction seeded with it yields -infinity for any input; +infinity
// (0x7F800000) would be the usual seed. Confirm this is intentional. The
// same constant is matched by the expectations that follow this function.
%cst = arith.constant 0xFF800000 : f32
affine.for %i = 0 to 256 {
// The inner loop carries the running minimum in %red_iter via iter_args.
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
%min = arith.minnumf %red_iter, %ld : f32
affine.yield %min : f32
}
affine.store %final_red, %out[%i] : memref<256xf32>
}
return
}

// CHECK-LABEL: func.func @vecdim_reduction_minnumf(
// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>,
// CHECK-SAME: %[[output:.*]]: memref<256xf32>) {
// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32
// CHECK: affine.for %{{.*}} = 0 to 256 {
// CHECK: %[[vzero:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
// CHECK: %[[poison:.*]] = ub.poison : f32
// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
// CHECK: %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32>
// CHECK: affine.yield %[[min]] : vector<128xf32>
// CHECK: }
// CHECK: %[[red_scalar:.*]] = vector.reduction <minnumf>, %[[vred]] : vector<128xf32> into f32
// CHECK: %[[final_red:.*]] = arith.minnumf %[[red_scalar]], %[[cst]] : f32
// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
// CHECK: }
// CHECK: return
// CHECK: }

// -----

// Test input: for each of the 256 rows of %in, reduces the 512 f32 elements
// of that row with arith.maxnumf and stores the result to %out at the row
// index.
func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
// Initial value of the reduction carried through the inner loop;
// 0xFF800000 is the f32 bit pattern for -infinity.
%cst = arith.constant 0xFF800000 : f32
affine.for %i = 0 to 256 {
// The inner loop carries the running maximum in %red_iter via iter_args.
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
%max = arith.maxnumf %red_iter, %ld : f32
affine.yield %max : f32
}
affine.store %final_red, %out[%i] : memref<256xf32>
}
return
}

// CHECK-LABEL: func.func @vecdim_reduction_maxnumf(
// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>,
// CHECK-SAME: %[[output:.*]]: memref<256xf32>) {
// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32
// CHECK: affine.for %{{.*}} = 0 to 256 {
// CHECK: %[[vzero:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
// CHECK: %[[poison:.*]] = ub.poison : f32
// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
// CHECK: %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32>
// CHECK: affine.yield %[[max]] : vector<128xf32>
// CHECK: }
// CHECK: %[[red_scalar:.*]] = vector.reduction <maxnumf>, %[[vred]] : vector<128xf32> into f32
// CHECK: %[[final_red:.*]] = arith.maxnumf %[[red_scalar]], %[[cst]] : f32
// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
// CHECK: }
// CHECK: return
// CHECK: }

// -----

Expand Down