Skip to content

Commit 9e9b500

Browse files
bchetioui authored and Google-ML-Automation committed
[XLA:GPU] Fix the derivation for the number of warps for tiled HLO computations.
The number of warps used to process a computation determines how many registers we are able to use concurrently. Therefore, looking at the largest (padded) tile size makes sense, since it determines the minimum number of elements that must be live concurrently. Previously, the logic erroneously only looked at the output tile sizes. This approach is not perfect, and may be further improved by, e.g., doing a live range analysis on the tiles of the computation.

PiperOrigin-RevId: 680668856
1 parent 12d351d commit 9e9b500

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

xla/service/gpu/model/gpu_indexing_performance_model.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,16 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton(
524524
LaunchDimensions
525525
GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion(
526526
const TiledHloComputation& tiled_hlo_computation) {
527-
const auto* tiled_root = tiled_hlo_computation.GetRoot();
528527
int64_t num_blocks = tiled_hlo_computation.num_output_tiles();
529-
int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes()));
528+
529+
// Decide on the number of warps to use based on the largest live tile size
530+
// at any given point within the computation.
531+
int64_t largest_live_tile_size = 1;
532+
for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) {
533+
largest_live_tile_size = std::max(
534+
largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes()));
535+
}
536+
int64_t num_warps = GetNumWarps(largest_live_tile_size);
530537

531538
return {static_cast<uint64_t>(num_blocks),
532539
static_cast<uint64_t>(num_warps * WarpSize())};

xla/service/gpu/model/gpu_indexing_performance_model_test.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,54 @@ ENTRY main {
620620
// and corresponds to 4 warps.
621621
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
622622
}
623+
624+
TEST_F(GpuIndexingPerformanceModelTest,
625+
NumberOfWarpsDependsOnLargestLiveTileSize) {
626+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
627+
HloModule m
628+
629+
add {
630+
param_0 = f32[] parameter(0)
631+
param_1 = f32[] parameter(1)
632+
ROOT add = f32[] add(param_0, param_1)
633+
}
634+
635+
fusion_computation {
636+
param_0 = f32[1,4096] parameter(0)
637+
c0 = f32[] constant(0)
638+
ROOT reduce = f32[1] reduce(param_0, c0), dimensions={1}, to_apply=add
639+
}
640+
641+
ENTRY main {
642+
param_0 = f32[1,4096] parameter(0)
643+
ROOT fusion = f32[1] fusion(param_0), kind=kCustom,
644+
calls=fusion_computation,
645+
backend_config={"fusion_backend_config": {"kind":"__triton"}}
646+
}
647+
)"));
648+
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
649+
module->entry_computation()->root_instruction());
650+
651+
SymbolicTileAnalysisOrError analysis_or_error =
652+
SymbolicTileAnalysis::AnalyzeFusion(
653+
*fusion_adaptor, &mlir_context_,
654+
/*emitter_specific_constraints_builder=*/nullptr);
655+
ASSERT_TRUE(std::holds_alternative<SymbolicTileAnalysis>(analysis_or_error));
656+
657+
TF_ASSERT_OK_AND_ASSIGN(
658+
TiledHloComputation tiled_hlo_computation,
659+
std::get<SymbolicTileAnalysis>(analysis_or_error)
660+
.ComputeTiledHloInstructions(/*tile_parameters=*/{1}));
661+
662+
LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis::
663+
GetLaunchDimensionsForTiledFusion(tiled_hlo_computation);
664+
EXPECT_EQ(launch_dimensions.num_blocks(), 1);
665+
666+
// The largest tile size is 1 * 4096, for which our implementation recommends
667+
// using 4 warps.
668+
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
669+
}
670+
623671
class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
624672
public:
625673
void CompareFlopsModels(absl::string_view hlo_module_string) {

0 commit comments

Comments
 (0)