@@ -18,25 +18,37 @@ limitations under the License.
1818#include < cstdint>
1919#include < vector>
2020
21+ #include " absl/algorithm/container.h"
2122#include " absl/log/check.h"
2223#include " absl/status/status.h"
2324#include " absl/status/statusor.h"
25+ #include " absl/strings/str_cat.h"
2426#include " absl/strings/str_format.h"
27+ #include " absl/strings/string_view.h"
28+ #include " absl/types/span.h"
2529#include " llvm/ADT/STLExtras.h"
2630#include " mlir/IR/AffineExpr.h"
2731#include " mlir/IR/AffineMap.h"
2832#include " mlir/IR/MLIRContext.h"
33+ #include " xla/codegen/tiling/tiling_specification.h"
2934#include " xla/hlo/analysis/indexing_analysis.h"
3035#include " xla/hlo/analysis/indexing_map.h"
36+ #include " xla/hlo/ir/hlo_casting_utils.h"
37+ #include " xla/hlo/ir/hlo_instructions.h"
38+ #include " xla/hlo/ir/hlo_opcode.h"
3139#include " xla/service/gpu/model/experimental/symbolic_expr.h"
40+ #include " xla/tsl/platform/errors.h"
41+ #include " xla/tsl/platform/statusor.h"
3242#include " xla/util.h"
3343
3444namespace xla {
3545
36- absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule (
37- const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
38- gpu::SymbolicExprContext* symbolic_expr_context) const {
39- mlir::MLIRContext* mlir_context = symbolic_expr_context->GetMLIRContext ();
46+ namespace {
47+
48+ // Helper to validate that an iteration space is compatible with a tile offsets
49+ // indexing map.
50+ absl::Status ValidateIterationSpace (const IterationSpace& iteration_space,
51+ const IndexingMap& tile_offsets_indexing) {
4052 if (iteration_space.size () != tile_offsets_indexing.GetDimVarsCount ()) {
4153 return absl::InvalidArgumentError (absl::StrFormat (
4254 " Expected iteration space to have exactly as many dimensions as there "
@@ -45,6 +57,30 @@ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
4557 iteration_space.size (), tile_offsets_indexing.GetDimVarsCount ()));
4658 }
4759
60+ std::vector<int64_t > iteration_space_dims;
61+ iteration_space_dims.reserve (iteration_space.size ());
62+
63+ for (const auto & [dim_id, dim_size] : iteration_space) {
64+ if (dim_id >= tile_offsets_indexing.GetDimVarsCount () || dim_id < 0 ) {
65+ return absl::InvalidArgumentError (absl::StrFormat (
66+ " Dimension id %d is out of bounds for tile offsets indexing map with "
67+ " %d dimensions. This can happen if " ,
68+ dim_id, tile_offsets_indexing.GetDimVarsCount ()));
69+ }
70+
71+ if (absl::c_linear_search (iteration_space_dims, dim_id)) {
72+ return absl::InvalidArgumentError (absl::StrFormat (
73+ " Iteration space contains multiple dimensions with id %d." , dim_id));
74+ }
75+ iteration_space_dims.push_back (dim_id);
76+ }
77+ return absl::OkStatus ();
78+ }
79+
80+ absl::StatusOr<IndexingMap> MajorToMinorScheduleImpl (
81+ const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
82+ gpu::SymbolicExprContext* symbolic_expr_context) {
83+ mlir::MLIRContext* mlir_context = symbolic_expr_context->GetMLIRContext ();
4884 mlir::AffineExpr program_id = mlir::getAffineDimExpr (0 , mlir_context);
4985
5086 std::vector<int64_t > iteration_space_sizes;
@@ -60,12 +96,6 @@ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
6096 for (auto [dim_info, tile_expr] : llvm::zip (
6197 iteration_space, DelinearizeIndex (iteration_space_sizes, program_id,
6298 symbolic_expr_context))) {
63- if (dim_info.dimension_id >= tile_exprs.size ()) {
64- return absl::InvalidArgumentError (absl::StrFormat (
65- " Dimension id %d is out of bounds for tile offsets indexing map with "
66- " %d dimensions. This can happen if " ,
67- dim_info.dimension_id , tile_exprs.size ()));
68- }
6999 tile_exprs[dim_info.dimension_id ] = tile_expr;
70100 }
71101 std::vector<IndexingMap::Variable> dim_vars{
@@ -81,5 +111,116 @@ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule(
81111 scheduled_indexing.RemoveUnusedSymbols ();
82112 return scheduled_indexing;
83113}
114+ } // namespace
115+
116+ absl::StatusOr<IndexingMap> MajorToMinorTiledHloSchedule::Schedule (
117+ const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
118+ gpu::SymbolicExprContext* ctx) const {
119+ TF_RETURN_IF_ERROR (
120+ ValidateIterationSpace (iteration_space, tile_offsets_indexing));
121+ return MajorToMinorScheduleImpl (tile_offsets_indexing, iteration_space, ctx);
122+ }
123+
124+ absl::StatusOr<TransposedDotTiledHloSchedule>
125+ TransposedDotTiledHloSchedule::Create (
126+ const TilingSpecification& tiling_specification) {
127+ const TilingSpecification::ParameterMapping& parameter_mapping =
128+ tiling_specification.parameter_mapping ();
129+ CHECK (!parameter_mapping.empty ());
130+ const HloDotInstruction* dot =
131+ ::xla::DynCast<HloDotInstruction>(parameter_mapping.front ().instruction );
132+ if (dot == nullptr ) {
133+ return absl::InvalidArgumentError (
134+ absl::StrCat (" TransposedDotTiledHloSchedule expects its root to be a "
135+ " dot instruction "
136+ " but got " ,
137+ parameter_mapping.front ().instruction ->ToString ()));
138+ }
139+ if (absl::c_any_of (absl::MakeSpan (parameter_mapping).subspan (1 ),
140+ [](const auto & param) {
141+ return param.instruction ->opcode () == HloOpcode::kDot ;
142+ })) {
143+ return absl::InvalidArgumentError (
144+ " TransposedDotTiledHloSchedule is only supported for "
145+ " TilingSpecifications specifying tiling for a single dot "
146+ " instruction." );
147+ }
148+
149+ int64_t num_lhs_non_contracting_dims =
150+ dot->operand (0 )->shape ().dimensions ().size () -
151+ dot->dot_dimension_numbers ().lhs_contracting_dimensions ().size () -
152+ dot->dot_dimension_numbers ().lhs_batch_dimensions ().size ();
153+
154+ int64_t num_rhs_non_contracting_dims =
155+ dot->operand (1 )->shape ().dimensions ().size () -
156+ dot->dot_dimension_numbers ().rhs_contracting_dimensions ().size () -
157+ dot->dot_dimension_numbers ().rhs_batch_dimensions ().size ();
158+
159+ constexpr absl::string_view kErrorFormat =
160+ " TransposedDotTiledHloSchedule is only supported for dot instructions "
161+ " with a single non-contracting dimension, but got %d non-contracting "
162+ " dimensions on the %s operand of %s." ;
163+
164+ if (num_lhs_non_contracting_dims != 1 ) {
165+ return absl::InvalidArgumentError (absl::StrFormat (
166+ kErrorFormat , num_lhs_non_contracting_dims, " lhs" , dot->ToString ()));
167+ }
168+
169+ if (num_rhs_non_contracting_dims != 1 ) {
170+ return absl::InvalidArgumentError (absl::StrFormat (
171+ kErrorFormat , num_rhs_non_contracting_dims, " rhs" , dot->ToString ()));
172+ }
173+
174+ // The shape of the dot's output is now known to always be of the form
175+ // [..., m, n]. This is because batch dimensions precede non-contracting
176+ // dimensions, the lhs non-contracting dimensions precede the rhs
177+ // non-contracting dimensions, and there is exactly one such dimension on
178+ // each side.
179+ //
180+ // Figure out the parameter index of the m and n dimensions within the op.
181+ int64_t m_local_parameter_index =
182+ parameter_mapping.front ().num_tiling_parameters - 2 ;
183+ int64_t n_local_parameter_index =
184+ parameter_mapping.front ().num_tiling_parameters - 1 ;
185+
186+ // Using the local parameter index, we can compute the global parameter index
187+ // (i.e. the parameter index within the sequence of all tiling parameters).
188+ TF_ASSIGN_OR_RETURN (int64_t m_dim_id, tiling_specification.ParameterIndex (
189+ dot, m_local_parameter_index));
190+ TF_ASSIGN_OR_RETURN (int64_t n_dim_id, tiling_specification.ParameterIndex (
191+ dot, n_local_parameter_index));
192+
193+ return TransposedDotTiledHloSchedule (tiling_specification, m_dim_id,
194+ n_dim_id);
195+ }
196+
197+ absl::StatusOr<IndexingMap> TransposedDotTiledHloSchedule::Schedule (
198+ const IndexingMap& tile_offsets_indexing, IterationSpace iteration_space,
199+ gpu::SymbolicExprContext* ctx) const {
200+ CHECK_EQ (iteration_space.size (), tiling_specification_.num_parameters ());
201+ TF_RETURN_IF_ERROR (
202+ ValidateIterationSpace (iteration_space, tile_offsets_indexing));
203+
204+ DimensionInfo m_dim_info = iteration_space[m_dim_id_];
205+ DimensionInfo n_dim_info = iteration_space[n_dim_id_];
206+
207+ if (m_dim_info.dimension_id != m_dim_id_) {
208+ return absl::InvalidArgumentError (absl::StrFormat (
209+ " Expected dimension at offset %d to have id %d but got %d." , m_dim_id_,
210+ m_dim_id_, m_dim_info.dimension_id ));
211+ }
212+ if (n_dim_info.dimension_id != n_dim_id_) {
213+ return absl::InvalidArgumentError (absl::StrFormat (
214+ " Expected dimension at offset %d to have id %d but got %d." , n_dim_id_,
215+ n_dim_id_, n_dim_info.dimension_id ));
216+ }
217+
218+ std::vector<DimensionInfo> transposed_iteration_space (iteration_space.begin (),
219+ iteration_space.end ());
220+ transposed_iteration_space[m_dim_id_] = n_dim_info;
221+ transposed_iteration_space[n_dim_id_] = m_dim_info;
222+ return MajorToMinorScheduleImpl (tile_offsets_indexing,
223+ transposed_iteration_space, ctx);
224+ }
84225
85226} // namespace xla
0 commit comments