 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Partition.h"
 #include "triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h"
@@ -202,17 +205,12 @@ LogicalResult DependencyRewriter::run() {
        llvm::zip(schedule.getPartitions(), partitionUseInfo)) {
     // The amount of buffering is based on the longest distance to a user.
     for (auto &[output, info] : useInfo) {
-      // FIXME: No IR support for passing simple scalars through shared
-      // memory.
-      auto tensorType = dyn_cast<RankedTensorType>(output.getType());
-      if (!tensorType) {
-        return mlir::emitWarning(output.getLoc(),
-                                 "FIXME: only tensor SSA dependencies between "
-                                 "partitions are supported");
-      }
+      b.setLoc(output.getLoc());
+      ImplicitLocOpBuilder endBuilder(b.getLoc(), loop->getNextNode());

-      Operation *defOp;
+      bool isScalar = false;
       Value tmp = output;
+      Operation *defOp;
       while (true) {
         if (auto arg = dyn_cast<BlockArgument>(tmp)) {
           tmp = loop.getBody()->getTerminator()->getOperand(arg.getArgNumber() -
@@ -222,14 +220,31 @@ LogicalResult DependencyRewriter::run() {
         defOp = tmp.getDefiningOp();
         break;
       }
+      Value val = output;
+      auto tensorType = dyn_cast<RankedTensorType>(output.getType());
+      if (!tensorType) {
+        isScalar = true;
+        b.setInsertionPointAfterValue(output);
+        auto mod = output.getParentRegion()->getParentOfType<ModuleOp>();
+        auto nWarps = lookupNumWarps(mod);
+        auto threadsPerWarp =
+            triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+        int CTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
+        Attribute encoding = getDefaultBlockedEncoding(
+            b.getContext(), {1}, nWarps, threadsPerWarp, CTAs);
+        tensorType = RankedTensorType::get({1}, output.getType(), encoding);
+        StageCluster srcStageCluster = getStageCluster(defOp);
+
+        defOp = b.createInto<triton::SplatOp>(partition, srcStageCluster,
+                                              tensorType, output);
+        val = defOp->getResult(0);
+      }

       // Buffer the value based on the greatest distance to a consumer
       // partition.
       int maxDistance = info.getMaxUseDistance(partition);

       // Allocate buffers for the value and its associated barriers.
-      b.setLoc(output.getLoc());
-      ImplicitLocOpBuilder endBuilder(b.getLoc(), loop->getNextNode());
       AsyncRef aref = allocateAsyncValue(tensorType, maxDistance);

       unsigned numConsumers = info.consumers.size();
@@ -249,20 +264,24 @@ LogicalResult DependencyRewriter::run() {
         // partition with it.
         Value value = b.createInto<LocalLoadOp>(*usePartition, sinkSrcCluster,
                                                 tensorType, view);
+        if (isScalar) {
+          value = b.createInto<triton::UnsplatOp>(*usePartition, sinkSrcCluster,
+                                                  value);
+        }
         for (OpOperand *use : uses)
           use->set(value);
         exitOp(b);
       }

       // Set up production of the value
-      if (isa<BlockArgument>(output))
+      if (isa<BlockArgument>(val))
         b.setInsertionPointToStart(loop.getBody());
       else
         b.setInsertionPointAfter(defOp);

       StageCluster srcStageCluster = getStageCluster(defOp);
       auto [view, exitOp] = aref.putView(b, partition, srcStageCluster);
-      b.createInto<LocalStoreOp>(partition, srcStageCluster, output, view);
+      b.createInto<LocalStoreOp>(partition, srcStageCluster, val, view);
       exitOp(b);
     }
   }
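
In short, this change replaces the old FIXME path, which emitted a warning and bailed out whenever a non-tensor SSA value crossed a partition boundary, with real support for scalars: the producer partition splats the scalar into a 1-element tensor (using a default blocked encoding derived from the module's num-warps, threads-per-warp, and num-CTAs), stores that tensor through the shared-memory aref buffer like any other dependency, and each consumer partition unsplats it back to a scalar after the load. A minimal sketch of the IR this should produce, assuming current `tt`/`ttg` op spellings; the encodings, element type, and memdesc details are illustrative, not taken from an actual test:

```mlir
// Producer partition: wrap the scalar so it can live in shared memory.
%wrapped = tt.splat %scalar : i32 -> tensor<1xi32, #blocked>
ttg.local_store %wrapped, %view
    : tensor<1xi32, #blocked> -> !ttg.memdesc<1xi32, #shared, #smem, mutable>

// Consumer partition: load the 1-element tensor and recover the scalar.
%loaded = ttg.local_load %view
    : !ttg.memdesc<1xi32, #shared, #smem, mutable> -> tensor<1xi32, #blocked>
%scalar2 = tt.unsplat %loaded : tensor<1xi32, #blocked>
```

Note that `val` (the splatted tensor, or the original `output` when it was already a tensor) is what feeds the `LocalStoreOp`, which is why the production-side checks were switched from `output` to `val`.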