Skip to content
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op<
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32
```

For lane `k`, returns the value from lane `(k + 1) % width`.
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
bigger than or equal to `width`, the value is poison and `valid` is `false`.

`up` example:

Expand All @@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op<
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32
```

For lane `k`, returns the value from lane `(k - 1) % width`.
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
smaller than `0`, the value is poison and `valid` is `false`.

`idx` example:

Expand Down
47 changes: 40 additions & 7 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,26 +435,59 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return rewriter.notifyMatchFailure(
shuffleOp, "shuffle width and target subgroup size mismatch");

// Ensure the offset is a signless/unsigned integer.
if (adaptor.getOffset().getType().isSignedInteger())
return rewriter.notifyMatchFailure(
shuffleOp, "shuffle offset must be a signless/unsigned integer");

Location loc = shuffleOp.getLoc();
Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
Value result;
Value validVal;

switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR:
case gpu::ShuffleMode::XOR: {
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
case gpu::ShuffleMode::IDX:
}
case gpu::ShuffleMode::IDX: {
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::DOWN: {
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());

Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
Value resultLaneId =
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
resultLaneId, adaptor.getWidth());
break;
}
case gpu::ShuffleMode::UP: {
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());

Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
Value resultLaneId =
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
auto i32Type = rewriter.getIntegerType(32);
validVal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, resultLaneId,
rewriter.create<arith::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
break;
default:
return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
}
}

rewriter.replaceOp(shuffleOp, {result, trueVal});
rewriter.replaceOp(shuffleOp, {result, validVal});
return success();
}

Expand Down
71 changes: 69 additions & 2 deletions mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ gpu.module @kernels {

// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.Constant true
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
// CHECK: %{{.+}} = spirv.Constant true
%result, %valid = gpu.shuffle xor %val, %mask, %width : f32
gpu.return
}
Expand Down Expand Up @@ -64,11 +64,78 @@ gpu.module @kernels {

// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.Constant true
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
// CHECK: %{{.+}} = spirv.Constant true
%result, %valid = gpu.shuffle idx %val, %mask, %width : f32
gpu.return
}
}

}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @shuffle_down()
gpu.func @shuffle_down() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32

// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
// CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32

%result, %valid = gpu.shuffle down %val, %offset, %width : f32
gpu.return
}
}

}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @shuffle_up()
gpu.func @shuffle_up() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32

// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
// CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
// CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32

%result, %valid = gpu.shuffle up %val, %offset, %width : f32
gpu.return
}
}

}
Loading