Skip to content

Commit ae9883c

Browse files
Fix for numerical issues in MatVec tests
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 4a66ccb commit ae9883c

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -82,26 +82,31 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
8282
}
8383

8484
if (ci.clusterSize >= 8) {
85-
Value dppResult = b.create<amdgpu::DPPOp>(
86-
loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
87-
b.getUnitAttr());
85+
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 4);
86+
Value dppResult =
87+
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
88+
amdgpu::DPPPerm::row_shr, permArg);
8889
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
8990
result, dppResult);
9091
}
9192

9293
if (ci.clusterSize >= 16) {
94+
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 8);
9395
Value dppResult =
9496
b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
95-
amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
97+
amdgpu::DPPPerm::row_shr, permArg);
9698
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
9799
result, dppResult);
98100
}
99101

102+
const int allRows = 0xf;
103+
const int allBanks = 0xf;
104+
100105
if (ci.clusterSize >= 32) {
101106
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 15);
102107
Value dppResult = b.create<amdgpu::DPPOp>(
103108
loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
104-
b.getUnitAttr(), 10, 15, false);
109+
b.getUnitAttr(), 0xa, allBanks, false);
105110
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
106111
result, dppResult);
107112
}
@@ -110,7 +115,7 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
110115
auto permArg = b.getIntegerAttr(b.getIntegerType(32), 31);
111116
Value dppResult = b.create<amdgpu::DPPOp>(
112117
loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
113-
b.getUnitAttr(), 12, 15, false);
118+
b.getUnitAttr(), allRows, allBanks, false);
114119
result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
115120
result, dppResult);
116121
}

0 commit comments

Comments (0)