@@ -2261,6 +2261,62 @@ struct MultiRotateCustomCallOptimize
22612261 }
22622262};
22632263
2264+ // / Detect whether this MultiSliceOp matches the cross-shard pattern:
2265+ // / 1. All strides are 1.
2266+ // / 2. For every sharded dimension except the multi-slice dimension,
2267+ // / start/limit span the full tensor extent.
2268+ // / 3. Along the multi-slice dimension, every slice's start falls within
2269+ // / one shard and its end falls within a different shard.
2270+ bool detectCrossShardPattern (Value operand, Operation *op,
2271+ ArrayRef<int64_t > startIndices,
2272+ ArrayRef<int64_t > limitIndices,
2273+ ArrayRef<int64_t > strides, int32_t dim,
2274+ int32_t amount, bool &needsSlice) {
2275+ // --- Condition 1: unit strides everywhere ---
2276+ if (!llvm::all_of (strides, [](int64_t s) { return s == 1 ; }))
2277+ return false ;
2278+
2279+ auto operandType = cast<RankedTensorType>(operand.getType ());
2280+ auto operandSharding = mlir::sdy::getSharding (operand);
2281+ if (!operandSharding) {
2282+ return false ;
2283+ }
2284+ ArrayRef<int64_t > shape = operandType.getShape ();
2285+ int64_t rank = shape.size ();
2286+
2287+ if (dim < 0 || dim >= rank)
2288+ return false ;
2289+
2290+ // --- Condition 2: full span on every sharded dim except `dim` ---
2291+ for (int64_t d = 0 ; d < rank; ++d) {
2292+ if (d == dim)
2293+ continue ;
2294+ int64_t numShards = getNumDevicesAlongDimension (operandSharding, d, op);
2295+ if (startIndices[d] != 0 || limitIndices[d] != shape[d]) {
2296+ needsSlice = true ;
2297+ if (numShards > 1 ) {
2298+ return false ;
2299+ }
2300+ }
2301+ }
2302+
2303+ // --- Condition 3: cross-shard slicing along `dim` ---
2304+ int64_t numShards = getNumDevicesAlongDimension (operandSharding, dim, op);
2305+ if (numShards <= 1 )
2306+ return false ; // Not sharded along the slice dimension.
2307+
2308+ int64_t dimSize = shape[dim];
2309+ int64_t shardSize = (dimSize + numShards - 1 ) / numShards;
2310+
2311+ if (startIndices[dim] > shardSize) {
2312+ return false ;
2313+ }
2314+ if (shape[dim] - limitIndices[dim] > shardSize) {
2315+ return false ;
2316+ }
2317+ return true ;
2318+ }
2319+
22642320struct MultiSliceCustomCallOptimize
22652321 : public OpRewritePattern<enzymexla::MultiSliceOp> {
22662322
@@ -2283,40 +2339,123 @@ struct MultiSliceCustomCallOptimize
22832339 if (slice->getParentOfType <sdy::ManualComputationOp>())
22842340 return failure ();
22852341
2286- auto rotateDimension = slice.getDimension ();
2342+ auto sliceDimension = slice.getDimension ();
22872343 auto shardings = mlir::sdy::getShardingPerValue (slice);
22882344 if (!shardings)
22892345 return rewriter.notifyMatchFailure (slice, " No sharding found." );
2290- auto rotateSharding = shardings.getSharding (0 );
2346+ auto sliceSharding = shardings.getSharding (0 );
2347+ for (int64_t i = 1 ; i < slice.getNumResults (); ++i) {
2348+ if (shardings.getSharding (i) != sliceSharding)
2349+ return rewriter.notifyMatchFailure (
2350+ slice, " Not all results have the same sharding" );
2351+ }
22912352
22922353 int64_t numDevicesAlongDimension =
2293- getNumDevicesAlongDimension (rotateSharding, rotateDimension , slice);
2354+ getNumDevicesAlongDimension (sliceSharding, sliceDimension , slice);
22942355
22952356 if (numDevicesAlongDimension == 1 ) {
22962357 return rewriter.notifyMatchFailure (
22972358 slice,
22982359 " numDevicesAlongDimension == 1. Communication is already optimized." );
22992360 }
23002361
2301- std::string start_indices =
2302- serializeDenseI64ArrayAttr (slice.getStartIndices ());
2303- std::string limit_indices =
2304- serializeDenseI64ArrayAttr (slice.getLimitIndices ());
2305- std::string strides = serializeDenseI64ArrayAttr (slice.getStrides ());
2362+ Value customCallOperand = slice.getOperand ();
2363+ auto operandSharding = mlir::sdy::getSharding (customCallOperand);
2364+ if (!operandSharding) {
2365+ return rewriter.notifyMatchFailure (slice, " No operand shardings" );
2366+ }
2367+ if (sliceSharding != operandSharding) {
2368+ return rewriter.notifyMatchFailure (slice,
2369+ " Mismatched input/output sharding" );
2370+ }
23062371
2307- std::string opaque = " dimension=" + std::to_string (rotateDimension) +
2372+ // Only lower to custom call if the cross-shard pattern is detected.
2373+ auto startIndices = SmallVector<int64_t >(slice.getStartIndices ());
2374+ auto limitIndices = SmallVector<int64_t >(slice.getLimitIndices ());
2375+ auto strideVals = SmallVector<int64_t >(slice.getStrides ());
2376+ bool needs_slice = false ;
2377+ if (!detectCrossShardPattern (customCallOperand, slice, startIndices,
2378+ limitIndices, strideVals, sliceDimension,
2379+ slice.getAmount (), needs_slice))
2380+ return rewriter.notifyMatchFailure (
2381+ slice, " MultiSlice does not match cross-shard pattern." );
2382+
2383+ // --- Replace the needs_slice bail-out and custom-call emission with this:
2384+ // ---
2385+
2386+ SmallVector<int64_t > finalStartIndices (startIndices);
2387+ SmallVector<int64_t > finalLimitIndices (limitIndices);
2388+ SmallVector<int64_t > finalStrides (strideVals);
2389+
2390+ if (needs_slice) {
2391+ // Emit a preliminary stablehlo::SliceOp that trims replicated
2392+ // (unsharded) dimensions down to the requested range, so that
2393+ // the MultiSlice custom call afterwards spans the full axis on
2394+ // every dimension except `dim`.
2395+ auto operandType = cast<RankedTensorType>(customCallOperand.getType ());
2396+ ArrayRef<int64_t > shape = operandType.getShape ();
2397+ int64_t rank = shape.size ();
2398+
2399+ auto operandSharding = sdy::getSharding (slice.getOperand ());
2400+
2401+ SmallVector<int64_t > preStart (rank);
2402+ SmallVector<int64_t > preLimit (rank);
2403+ SmallVector<int64_t > preStrides (rank, 1 );
2404+
2405+ for (int64_t d = 0 ; d < rank; ++d) {
2406+ if (d == sliceDimension) {
2407+ // Keep the full extent along the multi-slice dimension;
2408+ // the custom call handles cross-shard slicing there.
2409+ preStart[d] = 0 ;
2410+ preLimit[d] = shape[d];
2411+ } else {
2412+ int64_t numShards =
2413+ getNumDevicesAlongDimension (operandSharding, d, slice);
2414+ if (numShards <= 1 &&
2415+ (startIndices[d] != 0 || limitIndices[d] != shape[d])) {
2416+ // Replicated dim that doesn't span the full tensor —
2417+ // slice it now so the custom call can assume full extent.
2418+ preStart[d] = startIndices[d];
2419+ preLimit[d] = limitIndices[d];
2420+ // After pre-slicing, the custom call sees [0, newSize).
2421+ finalStartIndices[d] = 0 ;
2422+ finalLimitIndices[d] = limitIndices[d] - startIndices[d];
2423+ } else {
2424+ preStart[d] = 0 ;
2425+ preLimit[d] = shape[d];
2426+ }
2427+ }
2428+ }
2429+
2430+ auto preSliceOp = rewriter.create <stablehlo::SliceOp>(
2431+ slice.getLoc (), customCallOperand, preStart, preLimit, preStrides);
2432+
2433+ SmallVector<TensorShardingAttr> opShardings (1 , sliceSharding);
2434+ sdy::setShardings (preSliceOp, TensorShardingPerValueAttr::get (
2435+ rewriter.getContext (), opShardings));
2436+
2437+ customCallOperand = preSliceOp.getResult ();
2438+ }
2439+
2440+ std::string start_indices_str =
2441+ serializeDenseI64ArrayAttr (finalStartIndices);
2442+ std::string limit_indices_str =
2443+ serializeDenseI64ArrayAttr (finalLimitIndices);
2444+ std::string strides_str = serializeDenseI64ArrayAttr (finalStrides);
2445+
2446+ std::string opaque = " dimension=" + std::to_string (sliceDimension) +
23082447 " ,amount=" + std::to_string (slice.getAmount ()) +
2309- " ,start_indices=" + start_indices +
2310- " ,limit_indices=" + limit_indices +
2311- " ,strides=" + strides ;
2448+ " ,start_indices=" + start_indices_str +
2449+ " ,limit_indices=" + limit_indices_str +
2450+ " ,strides=" + strides_str ;
23122451
2313- auto fnSym = rewriter.getStringAttr (" _SPMDEnzymeInternalOp_MultiSlice " );
2452+ auto fnSym = rewriter.getStringAttr (" _SPMDInternalOp_MultiSlice " );
23142453
2315- SmallVector<TensorShardingAttr> opShardings (slice.getNumResults () ,
2316- rotateSharding );
2454+ SmallVector<TensorShardingAttr> opShardings (slice.getAmount () + 1 ,
2455+ sliceSharding );
23172456
23182457 auto ccall = rewriter.replaceOpWithNewOp <stablehlo::CustomCallOp>(
2319- slice, slice->getResultTypes (), slice-> getOperands () , fnSym,
2458+ slice, slice->getResultTypes (), ValueRange{customCallOperand} , fnSym,
23202459 /* has_side_effect=*/ rewriter.getBoolAttr (false ),
23212460 /* backend_config=*/ rewriter.getStringAttr (opaque),
23222461 /* api_version=*/ nullptr ,
0 commit comments