1111// ===----------------------------------------------------------------------===//
1212
1313#include " mlir/Dialect/Arith/IR/Arith.h"
14- #include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
1514#include " mlir/Dialect/GPU/IR/GPUDialect.h"
1615#include " mlir/Dialect/GPU/Transforms/Passes.h"
1716#include " mlir/Dialect/GPU/Utils/GPUUtils.h"
1817#include " mlir/Dialect/Vector/IR/VectorOps.h"
19- #include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
2018#include " mlir/IR/BuiltinTypes.h"
2119#include " mlir/IR/Location.h"
2220#include " mlir/IR/PatternMatch.h"
2624#include < cassert>
2725#include < cstdint>
2826
29- #define DPP
30-
3127using namespace mlir ;
3228
3329namespace {
@@ -192,8 +188,6 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
192188 function_ref<Value(Value)> unpackFn) {
193189 // Lane value always stays in the original type. We use it to perform arith
194190 // reductions.
195- llvm::errs () << " Cluster Stride: " << ci.clusterStride << " \n " ;
196- llvm::errs () << " Cluster Size: " << ci.clusterSize << " \n " ;
197191 Value laneVal = input;
198192 // Parallel reduction using butterfly shuffles.
199193 for (unsigned i = ci.clusterStride ; i < ci.clusterStride * ci.clusterSize ;
@@ -212,146 +206,6 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
212206 return laneVal;
213207}
214208
#ifdef DPP
/// Lowers a subgroup reduction to a sequence of AMDGPU DPP (Data Parallel
/// Primitives) cross-lane ops, combining each permuted partial result into
/// `result` with the arith reduction for `mode`. Handles power-of-two cluster
/// sizes up to 64 (the wavefront size); each `if` stage doubles the number of
/// lanes whose values have been folded together.
///
/// `packFn`/`unpackFn` are accepted for signature parity with
/// createSubgroupShuffleReduction but are currently unused: the DPP ops here
/// operate directly on the value in its original type.
/// NOTE(review): unlike the shuffle path, no final broadcast of the fully
/// reduced value is emitted (see TODO below) — confirm all lanes are meant to
/// observe the result before relying on this lowering.
Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
                                 gpu::AllReduceOperation mode,
                                 const ClusterInfo &ci,
                                 function_ref<Value(Value)> packFn,
                                 function_ref<Value(Value)> unpackFn) {
  Value result = input;

  // Folds the DPP-permuted partial value into the running reduction.
  auto combine = [&](Value dppResult) {
    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
                                        result, dppResult);
  };

  if (ci.clusterSize >= 2) {
    // row_shr by 1: combine each lane with its neighbor one lane below.
    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
    combine(b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                    amdgpu::DPPPerm::row_shr, permArg));
  }

  if (ci.clusterSize >= 4) {
    // row_shr by 2: combine pairs of already-reduced 2-lane groups.
    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
    combine(b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                    amdgpu::DPPPerm::row_shr, permArg));
  }

  if (ci.clusterSize >= 8) {
    // Mirror within each half-row to merge 4-lane groups into 8.
    combine(b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                    amdgpu::DPPPerm::row_half_mirror,
                                    b.getUnitAttr()));
  }

  if (ci.clusterSize >= 16) {
    // Mirror within each row to merge 8-lane groups into 16.
    combine(b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                    amdgpu::DPPPerm::row_mirror,
                                    b.getUnitAttr()));
  }

  if (ci.clusterSize >= 32) {
    // row_bcast_15 with rowMask 0xa (binary 1010, i.e. rows 1 and 3 written):
    // pushes each row's lane-15 value into the next row, merging 16 -> 32.
    combine(b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                    amdgpu::DPPPerm::row_bcast_15,
                                    b.getUnitAttr(), /*rowMask=*/0xa,
                                    /*bankMask=*/0xf, /*boundCtrl=*/false));
  }

  if (ci.clusterSize == 64) {
    // row_bcast_31 with rowMask 0xc (binary 1100, i.e. rows 2 and 3 written):
    // pushes lane 31's value into the upper half, merging 32 -> 64.
    combine(b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
                                    amdgpu::DPPPerm::row_bcast_31,
                                    b.getUnitAttr(), /*rowMask=*/0xc,
                                    /*bankMask=*/0xf, /*boundCtrl=*/false));
  }

  // TODO: broadcast the final value from the last lane (e.g. via
  // ROCDL::ReadLaneOp on lane 63) if every lane must hold the reduced result.
  assert(result.getType() == input.getType());
  return result;
}
#endif
291-
292- // Value createSubgroupDPPReduction(OpBuilder &b, Location loc,
293- // Value input, gpu::AllReduceOperation mode,
294- // const ClusterInfo &ci,
295- // function_ref<Value(Value)> packFn,
296- // function_ref<Value(Value)> unpackFn) {
297-
298- // Value result = input;
299- // if (ci.clusterSize >= 2) {
300- // auto permArg = b.getInt32(1);
301- // Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_shr, permArg);
302- // result = vector::makeArithReduction(builder, loc,
303- // gpu::convertReductionKind(mode),
304- // result, unpackFn(dppResult));
305- // }
306-
307- // if (ci.clusterSize >= 4) {
308- // auto permArg = builder.getInt32(2);
309- // Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_shr, permArg);
310- // result = vector::makeArithReduction(builder, loc,
311- // gpu::convertReductionKind(mode),
312- // result, unpackFn(dppResult));
313- // }
314-
315- // if (ci.clusterSize >= 8) {
316- // Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_half_mirror);
317- // result = vector::makeArithReduction(builder, loc,
318- // gpu::convertReductionKind(mode),
319- // result, unpackFn(dppResult));
320- // }
321-
322- // if (ci.clusterSize >= 16) {
323- // Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_mirror);
324- // result = vector::makeArithReduction(builder, loc,
325- // gpu::convertReductionKind(mode),
326- // result, unpackFn(dppResult));
327- // }
328-
329- // if (ci.clusterSize >= 32) {
330- // auto permArg = builder.getInt32(15);
331- // auto rowMask = builder.getInt32("0xa");
332- // auto bankMask = builder.getInt32("0xf");
333- // auto boundCtrl = builder.getBoolAttr(false);
334- // Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_bcast, permArg, rowMask, bankMask, boundCtrl);
335- // result = vector::makeArithReduction(builder, loc,
336- // gpu::convertReductionKind(mode),
337- // result, unpackFn(dppResult));
338- // }
339-
340- // if (ci.clusterSize == 64) {
341- // auto permArg = builder.getInt32(31);
342- // auto rowMask = builder.getInt32("0xc");
343- // auto bankMask = builder.getInt32("0xf");
344- // auto boundCtrl = builder.getBoolAttr(false);
345- // Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_bcast, permArg, rowMask, bankMask, boundCtrl);
346- // result = vector::makeArithReduction(builder, loc,
347- // gpu::convertReductionKind(mode),
348- // result, unpackFn(dppResult));
349- // }
350-
351- // assert(result.getType() == input.getType());
352- // return result;
353- // }
354-
355209// / Lowers scalar gpu subgroup reductions to a series of shuffles.
356210struct ScalarSubgroupReduceToShuffles final
357211 : OpRewritePattern<gpu::SubgroupReduceOp> {
@@ -363,7 +217,6 @@ struct ScalarSubgroupReduceToShuffles final
363217
364218 LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
365219 PatternRewriter &rewriter) const override {
366- llvm::errs () << " ScalarSubgroupReduceToShuffles" << " \n " ;
367220 if (op.getClusterSize ().has_value () != matchClustered) {
368221 return rewriter.notifyMatchFailure (
369222 op, llvm::formatv (" op is {0}clustered but pattern is configured to "
@@ -386,17 +239,10 @@ struct ScalarSubgroupReduceToShuffles final
386239 Location loc = op.getLoc ();
387240 // Since this is already a native shuffle scalar, no packing is necessary.
388241 if (elemBitwidth == shuffleBitwidth) {
389- llvm::errs () << " ScalarSubgroupReduceToShuffles - 1" << " \n " ;
390242 auto identityFn = [](Value v) { return v; };
391- #ifndef DPP
392243 rewriter.replaceOp (op, createSubgroupShuffleReduction (
393244 rewriter, loc, op.getValue (), op.getOp (), *ci,
394245 identityFn, identityFn));
395- #else
396- rewriter.replaceOp (op, createSubgroupDPPReduction (
397- rewriter, loc, op.getValue (), op.getOp (), *ci,
398- identityFn, identityFn));
399- #endif
400246 return success ();
401247 }
402248
@@ -414,15 +260,10 @@ struct ScalarSubgroupReduceToShuffles final
414260 rewriter.create <arith::TruncIOp>(loc, equivIntType, packedVal);
415261 return rewriter.create <arith::BitcastOp>(loc, valueTy, asInt);
416262 };
417- llvm::errs () << " ScalarSubgroupReduceToShuffles - 2" << " \n " ;
418- #ifndef DPP
263+
419264 rewriter.replaceOp (
420265 op, createSubgroupShuffleReduction (rewriter, loc, op.getValue (),
421266 op.getOp (), *ci, packFn, unpackFn));
422- #else
423- rewriter.replaceOp (op, createSubgroupDPPReduction (rewriter, loc, op.getValue (),
424- op.getOp (), *ci, packFn, unpackFn));
425- #endif
426267 return success ();
427268 }
428269
@@ -443,7 +284,6 @@ struct VectorSubgroupReduceToShuffles final
443284
444285 LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
445286 PatternRewriter &rewriter) const override {
446- llvm::errs () << " VectorSubgroupReduceToShuffles" << " \n " ;
447287 if (op.getClusterSize ().has_value () != matchClustered) {
448288 return rewriter.notifyMatchFailure (
449289 op, llvm::formatv (" op is {0}clustered but pattern is configured to "
0 commit comments