36 changes: 33 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -192,6 +192,32 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem,
return isAligned && isDistributableLhs && isDistributableRhs;
}

/// Checks if the schedule's tile sizes satisfy LDS DMA alignment constraints.
/// Each operand tile's element count must be a multiple of
/// (dmaSize / elementBitWidth) * subgroupSize so the tile can be copied with
/// fewer DMA transfers.
static bool isLDSDMAAligned(const GPUMMASchedule &schedule, int64_t lhsBitwidth,
int64_t rhsBitwidth, std::optional<int64_t> dmaSize,
int64_t subgroupSize) {
if (!dmaSize) {
return true;
}

int64_t mTile = schedule.getTotalMSize() * schedule.getTotalMTileSize();
int64_t nTile = schedule.getTotalNSize() * schedule.getTotalNTileSize();
int64_t kTile = schedule.getTotalKSize() * schedule.getTotalKTileSize();

int64_t lhsElements = mTile * kTile;
int64_t rhsElements = kTile * nTile;

int64_t lhsElementsPerTransfer = *dmaSize / lhsBitwidth * subgroupSize;
int64_t rhsElementsPerTransfer = *dmaSize / rhsBitwidth * subgroupSize;

bool isLhsAligned = (lhsElements % lhsElementsPerTransfer) == 0;
bool isRhsAligned = (rhsElements % rhsElementsPerTransfer) == 0;

return isLhsAligned && isRhsAligned;
}
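For intuition, a quick worked example of the check above (the concrete numbers are illustrative, not taken from this patch): with a 128-bit DMA size, 16-bit (bf16) operands, and a subgroup size of 64, each transfer covers (128 / 16) * 64 = 512 elements, so a 256x64 LHS tile passes while a 96x24 tile does not.

    // Illustrative values only; mirrors the arithmetic in isLDSDMAAligned.
    int64_t dmaSize = 128, bitwidth = 16, subgroupSize = 64;
    int64_t elementsPerTransfer = dmaSize / bitwidth * subgroupSize; // 512
    bool lhsOk = (256 * 64) % elementsPerTransfer == 0; // 16384 % 512 == 0 -> aligned
    bool lhsBad = (96 * 24) % elementsPerTransfer == 0; //  2304 % 512 != 0 -> rejected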

/// Tries to fit the schedule into shared memory by decrementing the size of the
/// schedule dimensions from outermost to innermost until a valid schedule is
/// found. The schedule sizes are reduced in the order of mTileSizes,
@@ -666,7 +692,8 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
int64_t subgroupSize, std::optional<int64_t> wgpCount, Location loc,
bool transposedLhs, bool transposedRhs, bool canUpcastAcc,
bool mustBeAligned, bool doCPromotion, int64_t splitReductionTripCnt) {
bool mustBeAligned, bool doCPromotion, int64_t splitReductionTripCnt,
std::optional<int64_t> dmaSize) {

SmallVector<GPUIntrinsicType> sortedIntrinsics =
sortMMAIntrinsics(problem, intrinsics);
@@ -697,9 +724,11 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
problem.aScaleType ? problem.aScaleType.getIntOrFloatBitWidth() : 0;
int64_t rhsScaleBitwidth =
problem.bScaleType ? problem.bScaleType.getIntOrFloatBitWidth() : 0;
bool isAligned =
bool isScheduleValid =
isValidMMASchedule(problem, schedule, mustBeAligned, subgroupSize,
transposedLhs, transposedRhs);
bool isDmaAligned = isLDSDMAAligned(schedule, lhsBitwidth, rhsBitwidth,
dmaSize, subgroupSize);
int64_t sharedMemoryUsed = calculateOperandsSharedMemoryUsedInBytes(
schedule, lhsBitwidth, rhsBitwidth, lhsScaleBitwidth,
rhsScaleBitwidth, problem.numHorizontallyFusedOps);
@@ -717,7 +746,8 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
<< "Predicted Shared Memory Used by Schedule: " << sharedMemoryUsed
<< " bytes";

bool isValid = isAligned && sharedMemoryUsed <= sharedMemLimitInBytes;
bool isValid = isScheduleValid && isDmaAligned &&
sharedMemoryUsed <= sharedMemLimitInBytes;
if (isValid) {
// Only emit remark for the shared memory usage of the valid schedule.
remark::analysis(loc, remark::RemarkOpts::name("SharedMemoryUsage")
@@ -151,13 +151,17 @@ struct GPUMMASchedule {
/// When |doCPromotion| is true, the accumulator uses shared memory. This can be
/// due to padding requirements or because the operation has an existing
/// accumulator that needs to be loaded from global memory (matmul_accumulate).
/// When |dmaSize| is provided, ensures each operand tile's element count is a
/// multiple of (dmaSize / elementBitWidth) * subgroupSize so tiles can be
/// copied with fewer LDS DMA transfers.
FailureOr<GPUMMASchedule> deduceMMASchedule(
const GPUMatmulShapeType &problem, ArrayRef<GPUIntrinsicType> intrinsics,
const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes,
int64_t subgroupSize, std::optional<int64_t> cuCount, Location loc,
bool transposedLhs = false, bool transposedRhs = false,
bool canUpcastAcc = false, bool mustBeAligned = true,
bool doCPromotion = false, int64_t splitReductionTripCnt = 0);
bool doCPromotion = false, int64_t splitReductionTripCnt = 0,
std::optional<int64_t> dmaSize = std::nullopt);

/// Returns a schedule for the pvMatmul in attention using one of the given MMA
/// |intrinsics| to target the given attention matmul problems, |qkMatmul|
compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

@@ -24,6 +24,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Attributes.h"
@@ -45,6 +46,23 @@
constexpr int64_t kCacheLineSizeBits = 128 * 8;
constexpr int64_t kPreferredCopyNumBits = 128;

/// Returns true if the target supports global load DMA (LDS DMA) operations.
/// Only CDNA4+ (gfx950 and newer) architectures support this feature.
/// Excludes RDNA cards (gfx10xx, gfx11xx, gfx12xx) which have major version
/// >= 10.
static bool targetSupportsGlobalLoadDMA(IREE::GPU::TargetAttr target) {
StringRef targetArch = target.getArch();
auto maybeChipset = amdgpu::Chipset::parse(targetArch);
if (failed(maybeChipset)) {
return false;
}
// Only enable for CDNA4+ (gfx950+). Exclude RDNA cards (gfx10xx, gfx11xx,
// gfx12xx). CDNA cards have major version 9, RDNA cards have major version
// >= 10.
constexpr amdgpu::Chipset kGfx950{9, 5, 0};
return maybeChipset->majorVersion == 9 && *maybeChipset >= kGfx950;
}

Comment on lines +49 to +65 (Contributor Author):

This is taken from #23230
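A minimal sketch of how this gate behaves for a few architectures, assuming the usual gfxNNN encoding handled by amdgpu::Chipset::parse (values shown for illustration only):

    // Illustrative only; mirrors targetSupportsGlobalLoadDMA above.
    // "gfx942"  -> {9, 4, 2}  : major == 9 but < gfx950   -> false (CDNA3)
    // "gfx950"  -> {9, 5, 0}  : major == 9 and >= gfx950  -> true  (CDNA4)
    // "gfx1100" -> {11, 0, 0} : major != 9 (RDNA3)        -> false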

//===----------------------------------------------------------------------===//
// Lowering Config Selection
//===----------------------------------------------------------------------===//
@@ -436,15 +454,27 @@
wgpCount = chip.getWgpCount();
}

// To leverage LDS DMA with fewer transfers, tile sizes must be aligned to the
// largest supported DMA size.
std::optional<int64_t> dmaSize = std::nullopt;
if (targetSupportsGlobalLoadDMA(target)) {
DenseI64ArrayAttr dmaSizesAttr = target.getWgp().getDmaSizes();
if (dmaSizesAttr && !dmaSizesAttr.empty()) {
ArrayRef<int64_t> dmaSizes = dmaSizesAttr.asArrayRef();
dmaSize = *llvm::max_element(dmaSizes);
}
}

// First try to find a schedule with an exactly matching intrinsic.
std::optional<GPUMMASchedule> schedule = deduceMMASchedule(
problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize,
wgpCount, loc, transposedLhs, transposedRhs, /*canUpcastAcc=*/false,
/*mustBeAligned=*/mustBeAligned, doCPromotion, splitReductionTripCnt);
/*mustBeAligned=*/mustBeAligned, doCPromotion, splitReductionTripCnt,
dmaSize);
return schedule;
}
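As a rough illustration of the dmaSize selection above (the textual attribute form shown here is an assumption, not quoted from this patch): if the target's wgp attribute carries DMA sizes of [32, 128] bits, llvm::max_element picks 128, and that value is threaded into deduceMMASchedule so that isLDSDMAAligned can reject schedules whose operand tiles are not multiples of (128 / elementBitWidth) * subgroupSize elements.

    // Hypothetical target description, for illustration only:
    //   wgp = <..., dma_sizes = [32, 128], ...>   // bits
    // -> dmaSize = 128, forwarded as the trailing argument to deduceMMASchedule.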

struct ConvToIgemmInfo {

Check warning (GitHub Actions / clang-tidy) at compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp:477:8 [misc-use-internal-linkage]:
struct 'ConvToIgemmInfo' can be moved into an anonymous namespace to enforce internal linkage
bool isBatchDimLast = false;
bool isSpatialDimLast = false;
linalg::ConvolutionDimensions convDims;