Skip to content

Commit 89fc913

Browse files
bchetiouiGoogle-ML-Automation
authored and committed
[XLA] Implement a TiledHloSchedule that transposes the iteration order over the non-contracting dimensions of a dot.
A concrete use case for such a schedule is a matrix multiplication in which a chunk of shape `(block_m, k)` of the left-hand side argument fully fits into L2. The transposed iteration order steps through the `n` dimension first, which allows loads of left-hand side tiles to hit the L2 cache more often. This schedule is intentionally restricted at the moment in order to unblock launching the generic Triton emitter for GEMMs. PiperOrigin-RevId: 820214481
1 parent 832c86a commit 89fc913

File tree

4 files changed

+486
-12
lines changed

4 files changed

+486
-12
lines changed

xla/codegen/tiling/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,18 @@ cc_library(
118118
srcs = ["tiled_hlo_schedule.cc"],
119119
hdrs = ["tiled_hlo_schedule.h"],
120120
deps = [
121+
":tiling_specification",
121122
"//xla:util",
122123
"//xla/hlo/analysis:indexing_analysis",
124+
"//xla/hlo/ir:hlo",
123125
"//xla/service/gpu/model/experimental:symbolic_expr",
126+
"//xla/tsl/platform:errors",
127+
"//xla/tsl/platform:statusor",
128+
"@com_google_absl//absl/algorithm:container",
124129
"@com_google_absl//absl/log:check",
125130
"@com_google_absl//absl/status",
126131
"@com_google_absl//absl/status:statusor",
132+
"@com_google_absl//absl/strings",
127133
"@com_google_absl//absl/strings:str_format",
128134
"@com_google_absl//absl/types:span",
129135
"@llvm-project//llvm:Support",
@@ -135,15 +141,23 @@ xla_cc_test(
135141
name = "tiled_hlo_schedule_test",
136142
srcs = ["tiled_hlo_schedule_test.cc"],
137143
deps = [
144+
":symbolic_tile_analysis",
138145
":tiled_hlo_schedule",
146+
":tiling_specification",
139147
"//xla/hlo/analysis:indexing_analysis",
140148
"//xla/hlo/analysis:interval",
149+
"//xla/hlo/ir:hlo",
141150
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
151+
"//xla/hlo/testlib:verified_hlo_module",
142152
"//xla/service/gpu/model/experimental:symbolic_expr",
153+
"//xla/tsl/lib/core:status_test_util",
143154
"//xla/tsl/platform:statusor",
155+
"@com_google_absl//absl/log:check",
144156
"@com_google_absl//absl/status",
145157
"@com_google_absl//absl/status:status_matchers",
158+
"@com_google_absl//absl/strings:string_view",
146159
"@com_google_googletest//:gtest_main",
160+
"@llvm-project//llvm:Support",
147161
"@llvm-project//mlir:IR",
148162
],
149163
)

xla/codegen/tiling/tiled_hlo_schedule.cc

Lines changed: 151 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,37 @@ limitations under the License.
1818
#include <cstdint>
1919
#include <vector>
2020

21+
#include "absl/algorithm/container.h"
2122
#include "absl/log/check.h"
2223
#include "absl/status/status.h"
2324
#include "absl/status/statusor.h"
25+
#include "absl/strings/str_cat.h"
2426
#include "absl/strings/str_format.h"
27+
#include "absl/strings/string_view.h"
28+
#include "absl/types/span.h"
2529
#include "llvm/ADT/STLExtras.h"
2630
#include "mlir/IR/AffineExpr.h"
2731
#include "mlir/IR/AffineMap.h"
2832
#include "mlir/IR/MLIRContext.h"
33+
#include "xla/codegen/tiling/tiling_specification.h"
2934
#include "xla/hlo/analysis/indexing_analysis.h"
3035
#include "xla/hlo/analysis/indexing_map.h"
36+
#include "xla/hlo/ir/hlo_casting_utils.h"
37+
#include "xla/hlo/ir/hlo_instructions.h"
38+
#include "xla/hlo/ir/hlo_opcode.h"
3139
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
40+
#include "xla/tsl/platform/errors.h"
41+
#include "xla/tsl/platform/statusor.h"
3242
#include "xla/util.h"
3343

3444
namespace xla {
3545

36-
absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
37-
const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
38-
gpu::SymbolicExprContext* symbolic_expr_context) const {
39-
mlir::MLIRContext* mlir_context = symbolic_expr_context->GetMLIRContext();
46+
namespace {
47+
48+
// Helper to validate that an iteration space is compatible with a tile offsets
49+
// indexing map.
50+
absl::Status ValidateIterationSpace(const IterationSpace& iteration_space,
51+
const IndexingMap& tile_offsets_indexing) {
4052
if (iteration_space.size() != tile_offsets_indexing.GetDimVarsCount()) {
4153
return absl::InvalidArgumentError(absl::StrFormat(
4254
"Expected iteration space to have exactly as many dimensions as there "
@@ -45,6 +57,30 @@ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
4557
iteration_space.size(), tile_offsets_indexing.GetDimVarsCount()));
4658
}
4759

60+
std::vector<int64_t> iteration_space_dims;
61+
iteration_space_dims.reserve(iteration_space.size());
62+
63+
for (const auto& [dim_id, dim_size] : iteration_space) {
64+
if (dim_id >= tile_offsets_indexing.GetDimVarsCount() || dim_id < 0) {
65+
return absl::InvalidArgumentError(absl::StrFormat(
66+
"Dimension id %d is out of bounds for tile offsets indexing map with "
67+
"%d dimensions. This can happen if ",
68+
dim_id, tile_offsets_indexing.GetDimVarsCount()));
69+
}
70+
71+
if (absl::c_linear_search(iteration_space_dims, dim_id)) {
72+
return absl::InvalidArgumentError(absl::StrFormat(
73+
"Iteration space contains multiple dimensions with id %d.", dim_id));
74+
}
75+
iteration_space_dims.push_back(dim_id);
76+
}
77+
return absl::OkStatus();
78+
}
79+
80+
absl::StatusOr<IndexingMap> MajorToMinorScheduleImpl(
81+
const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
82+
gpu::SymbolicExprContext* symbolic_expr_context) {
83+
mlir::MLIRContext* mlir_context = symbolic_expr_context->GetMLIRContext();
4884
mlir::AffineExpr program_id = mlir::getAffineDimExpr(0, mlir_context);
4985

5086
std::vector<int64_t> iteration_space_sizes;
@@ -60,12 +96,6 @@ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
6096
for (auto [dim_info, tile_expr] : llvm::zip(
6197
iteration_space, DelinearizeIndex(iteration_space_sizes, program_id,
6298
symbolic_expr_context))) {
63-
if (dim_info.dimension_id >= tile_exprs.size()) {
64-
return absl::InvalidArgumentError(absl::StrFormat(
65-
"Dimension id %d is out of bounds for tile offsets indexing map with "
66-
"%d dimensions. This can happen if ",
67-
dim_info.dimension_id, tile_exprs.size()));
68-
}
6999
tile_exprs[dim_info.dimension_id] = tile_expr;
70100
}
71101
std::vector<IndexingMap::Variable> dim_vars{
@@ -81,5 +111,116 @@ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
81111
scheduled_indexing.RemoveUnusedSymbols();
82112
return scheduled_indexing;
83113
}
114+
} // namespace
115+
116+
// Produces the major-to-minor schedule for `tile_offsets_indexing`.
//
// The iteration space is first checked against the indexing map with
// `ValidateIterationSpace`; any validation error is returned as-is. On
// success, the work is delegated to `MajorToMinorScheduleImpl`.
absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
    const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
    gpu::SymbolicExprContext* ctx) const {
  const absl::Status validation =
      ValidateIterationSpace(iteration_space, tile_offsets_indexing);
  TF_RETURN_IF_ERROR(validation);

  return MajorToMinorScheduleImpl(tile_offsets_indexing, iteration_space, ctx);
}
123+
124+
// Builds a `TransposedDotTiledHloSchedule` from `tiling_specification`.
//
// Returns an `InvalidArgument` error when:
//   * the first entry of the parameter mapping is not a dot instruction;
//   * any later entry of the parameter mapping is a dot instruction;
//   * the dot does not have exactly one non-contracting dimension on its lhs
//     operand, or on its rhs operand.
// Errors from `TilingSpecification::ParameterIndex` are propagated.
// CHECK-fails if the parameter mapping is empty.
absl::StatusOr<TransposedDotTiledHloSchedule>
TransposedDotTiledHloSchedule::Create(
    const TilingSpecification& tiling_specification) {
  const TilingSpecification::ParameterMapping& mapping =
      tiling_specification.parameter_mapping();
  CHECK(!mapping.empty());

  const auto& root_entry = mapping.front();
  const HloDotInstruction* dot =
      ::xla::DynCast<HloDotInstruction>(root_entry.instruction);
  if (dot == nullptr) {
    return absl::InvalidArgumentError(absl::StrCat(
        "TransposedDotTiledHloSchedule expects its root to be a dot "
        "instruction but got ",
        root_entry.instruction->ToString()));
  }

  // Only the first entry is allowed to be a dot; reject any other one.
  for (const auto& entry : absl::MakeSpan(mapping).subspan(1)) {
    if (entry.instruction->opcode() == HloOpcode::kDot) {
      return absl::InvalidArgumentError(
          "TransposedDotTiledHloSchedule is only supported for "
          "TilingSpecifications specifying tiling for a single dot "
          "instruction.");
    }
  }

  // Non-contracting dimensions are whatever remains of an operand's rank once
  // its contracting and batch dimensions have been subtracted.
  const auto& dot_dims = dot->dot_dimension_numbers();
  const int64_t lhs_free_dim_count =
      dot->operand(0)->shape().dimensions().size() -
      dot_dims.lhs_contracting_dimensions().size() -
      dot_dims.lhs_batch_dimensions().size();
  const int64_t rhs_free_dim_count =
      dot->operand(1)->shape().dimensions().size() -
      dot_dims.rhs_contracting_dimensions().size() -
      dot_dims.rhs_batch_dimensions().size();

  constexpr absl::string_view kErrorFormat =
      "TransposedDotTiledHloSchedule is only supported for dot instructions "
      "with a single non-contracting dimension, but got %d non-contracting "
      "dimensions on the %s operand of %s.";

  if (lhs_free_dim_count != 1) {
    return absl::InvalidArgumentError(absl::StrFormat(
        kErrorFormat, lhs_free_dim_count, "lhs", dot->ToString()));
  }
  if (rhs_free_dim_count != 1) {
    return absl::InvalidArgumentError(absl::StrFormat(
        kErrorFormat, rhs_free_dim_count, "rhs", dot->ToString()));
  }

  // At this point the dot's output shape is of the form [..., m, n]: batch
  // dimensions come before non-contracting dimensions, lhs non-contracting
  // dimensions come before rhs ones, and each side has exactly one of them.
  // Hence m and n are taken to be the last two tiling parameters of the dot.
  const int64_t m_local_index = root_entry.num_tiling_parameters - 2;
  const int64_t n_local_index = root_entry.num_tiling_parameters - 1;

  // Translate the op-local parameter indices into indices within the sequence
  // of all tiling parameters.
  TF_ASSIGN_OR_RETURN(int64_t m_dim_id, tiling_specification.ParameterIndex(
                                            dot, m_local_index));
  TF_ASSIGN_OR_RETURN(int64_t n_dim_id, tiling_specification.ParameterIndex(
                                            dot, n_local_index));

  return TransposedDotTiledHloSchedule(tiling_specification, m_dim_id,
                                       n_dim_id);
}
196+
197+
// Schedules `tile_offsets_indexing` with the `m` and `n` entries of
// `iteration_space` swapped, then iterates major-to-minor over the result.
//
// Returns an `InvalidArgument` error when:
//   * `iteration_space` does not have exactly as many entries as the stored
//     `TilingSpecification` has parameters;
//   * `ValidateIterationSpace` rejects `iteration_space`;
//   * the entries at offsets `m_dim_id_` / `n_dim_id_` do not carry the
//     dimension ids `m_dim_id_` / `n_dim_id_` respectively.
absl::StatusOr<IndexingMap> TransposedDotTiledHloSchedule::Schedule(
    const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
    gpu::SymbolicExprContext* ctx) const {
  // A size mismatch is a property of the caller-provided argument, so it is
  // reported through the returned status like every other invalid input
  // rather than aborting the process. This check also guards the indexing by
  // `m_dim_id_` and `n_dim_id_` below.
  const int64_t num_iteration_dims = iteration_space.size();
  const int64_t num_parameters = tiling_specification_.num_parameters();
  if (num_iteration_dims != num_parameters) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected iteration space to have exactly as many dimensions as the "
        "tiling specification has parameters, but got %d and %d.",
        num_iteration_dims, num_parameters));
  }
  TF_RETURN_IF_ERROR(
      ValidateIterationSpace(iteration_space, tile_offsets_indexing));

  const DimensionInfo m_dim_info = iteration_space[m_dim_id_];
  const DimensionInfo n_dim_info = iteration_space[n_dim_id_];

  // The swap below is positional, so the entries found at the `m` and `n`
  // offsets must actually describe the `m` and `n` dimensions.
  if (m_dim_info.dimension_id != m_dim_id_) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected dimension at offset %d to have id %d but got %d.", m_dim_id_,
        m_dim_id_, m_dim_info.dimension_id));
  }
  if (n_dim_info.dimension_id != n_dim_id_) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected dimension at offset %d to have id %d but got %d.", n_dim_id_,
        n_dim_id_, n_dim_info.dimension_id));
  }

  // Copy the iteration space and exchange the `m` and `n` entries.
  std::vector<DimensionInfo> transposed_iteration_space(iteration_space.begin(),
                                                        iteration_space.end());
  transposed_iteration_space[m_dim_id_] = n_dim_info;
  transposed_iteration_space[n_dim_id_] = m_dim_info;
  return MajorToMinorScheduleImpl(tile_offsets_indexing,
                                  transposed_iteration_space, ctx);
}
84225

85226
} // namespace xla

xla/codegen/tiling/tiled_hlo_schedule.h

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "absl/status/statusor.h"
2222
#include "absl/types/span.h"
23+
#include "xla/codegen/tiling/tiling_specification.h"
2324
#include "xla/hlo/analysis/indexing_map.h"
2425
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
2526

@@ -68,7 +69,7 @@ class TiledHloSchedule {
6869
// themselves);
6970
virtual absl::StatusOr<IndexingMap> Schedule(
7071
const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
71-
gpu::SymbolicExprContext* symbolic_expr_context) const = 0;
72+
gpu::SymbolicExprContext* ctx) const = 0;
7273
};
7374

7475
// The indexing map returned by this schedule iterates over the iteration space
@@ -79,7 +80,45 @@ class MajorToMinorTiledHloSchedule : public TiledHloSchedule {
7980
public:
8081
absl::StatusOr<IndexingMap> Schedule(
8182
const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
82-
gpu::SymbolicExprContext* symbolic_expr_context) const override;
83+
gpu::SymbolicExprContext* ctx) const override;
84+
};
85+
86+
// Given a `TilingSpecification` where some of the output tile sizes are
87+
// provided by a `dot` operation with one left-hand-side and one
88+
// right-hand-side non-contracting dimensions, this schedule transposes the
89+
// iteration pattern over these output dimensions.
90+
//
91+
// This schedule is only constructible when the underlying `TilingSpecification`
92+
// contains a single `dot` node.
93+
//
94+
// TODO(b/417977182): this is implemented as a very bespoke pattern to unblock
95+
// the launch of the generic emitter. We probably will want to subsume this with
96+
// a more flexible approach for user-specified transposed schedules (that don't
97+
// rely on the "dot" instruction being at the root).
98+
class TransposedDotTiledHloSchedule : public TiledHloSchedule {
99+
public:
100+
absl::StatusOr<IndexingMap> Schedule(
101+
const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
102+
gpu::SymbolicExprContext* ctx) const override;
103+
104+
static absl::StatusOr<TransposedDotTiledHloSchedule> Create(
105+
const TilingSpecification& tiling_specification);
106+
107+
private:
108+
TransposedDotTiledHloSchedule(const TilingSpecification& tiling_specification,
109+
int64_t m_dim_id, int64_t n_dim_id)
110+
: tiling_specification_(tiling_specification),
111+
m_dim_id_(m_dim_id),
112+
n_dim_id_(n_dim_id) {}
113+
114+
// The `TilingSpecification` used to construct this schedule.
115+
TilingSpecification tiling_specification_;
116+
// The index of the `m` dimension within the parameter mapping of the
117+
// `TilingSpecification`.
118+
int64_t m_dim_id_;
119+
// The index of the `n` dimension within the parameter mapping of the
120+
// `TilingSpecification`.
121+
int64_t n_dim_id_;
83122
};
84123

85124
// TODO(b/417977182): implement the `PlanarSnakeTiledHloSchedule` schedule.

0 commit comments

Comments
 (0)