@@ -23,7 +23,7 @@ index 23644329d2..44ffa9c943 100644
+ )
diff --git a/requirements.in b/requirements.in
new file mode 100644
- index 0000000000..53c8b0828d
+ index 0000000000..66dd73a662
--- /dev/null
+++ b/requirements.in
@@ -0,0 +1,10 @@
@@ -37,7 +37,6 @@ index 0000000000..53c8b0828d
+ scipy<1.12.0
+ ml_dtypes>=0.4.0
+ lit
- \ No newline at end of file
diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt
new file mode 100644
index 0000000000..20bd3ed1de
@@ -1098,18 +1097,21 @@ index d7675e1b6a..17da528bec 100644

namespace xla {
diff --git a/xla/service/BUILD b/xla/service/BUILD
- index a0b65b2012..73ad7b142f 100644
+ index a0b65b2012..9fc6094208 100644
--- a/xla/service/BUILD
+++ b/xla/service/BUILD
- @@ -3,6 +3,7 @@
-
- load("@bazel_skylib//rules:build_test.bzl", "build_test")
- load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
- + load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl_is_configured")
- load(
- "@local_config_rocm//rocm:build_defs.bzl",
+ @@ -8,6 +8,10 @@ load(
"if_rocm",
- @@ -1473,6 +1474,9 @@ cc_library(
+ "if_rocm_is_configured",
+ )
+ + load(
+ + "@local_config_sycl//sycl:build_defs.bzl",
+ + "if_sycl_is_configured",
+ + )
+ load(
+ "@tsl//tsl/platform:build_config.bzl",
+ "tf_proto_library",
+ @@ -1473,6 +1477,9 @@ cc_library(
]) + if_rocm_is_configured([
"//xla/service/gpu:amdgpu_compiler",
"//xla/stream_executor/rocm:stream_executor_rocm",
@@ -1119,15 +1121,15 @@ index a0b65b2012..73ad7b142f 100644
]),
)

- @@ -4131,6 +4135,7 @@ cc_library(
+ @@ -4131,6 +4138,7 @@ cc_library(
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/host:host_platform_id",
"//xla/stream_executor/rocm:rocm_platform_id",
+ "@intel_extension_for_openxla//xla/stream_executor/sycl:sycl_platform_id",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
- @@ -7572,6 +7577,19 @@ cc_library(
+ @@ -7572,6 +7580,19 @@ cc_library(
],
)

@@ -1147,7 +1149,7 @@ index a0b65b2012..73ad7b142f 100644
cc_library(
name = "scatter_simplifier",
srcs = ["scatter_simplifier.cc"],
- @@ -7653,8 +7671,10 @@ cc_library(
+ @@ -7653,8 +7674,10 @@ cc_library(
deps = [
":hlo_creation_utils",
":hlo_pass",
@@ -1698,7 +1700,7 @@ index e9cb21b9fa..1ba8c60b50 100644
MakeGetTupleElementHlo(new_conv, 0));
TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr));
diff --git a/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
- index f03fe4f0fa..468fa5c6dd 100644
+ index f03fe4f0fa..d7975eba83 100644
--- a/xla/service/gpu/cudnn_fused_mha_rewriter.cc
+++ b/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -234,12 +234,14 @@ auto GetUnfusedReduceMaxSumSoftmaxPattern(
@@ -1794,18 +1796,18 @@ index f03fe4f0fa..468fa5c6dd 100644
if (is_flash_attention) {
if (is_causal_mask) {
// if bias is causal mask, needs to remove bias from name
- @@ -1098,6 +1115,11 @@ absl::StatusOr<bool> IsMHABlockSupported(
+ @@ -1097,6 +1114,11 @@ absl::StatusOr<bool> IsMHABlockSupported(
+ }
}
}
- return is_flash_attention;
+ #else
- + if (!is_flash_attention || is_causal_mask) {
+ + if (is_causal_mask) {
+ return false;
- + } else return true;
+ + }
+ #endif
+ return is_flash_attention;
}

- absl::StatusOr<HloInstruction*> CanonicalizeBatchedGemmForcuDNNFMHA(
@@ -1627,6 +1649,7 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
comp->parent()->config().debug_options();
const se::dnn::VersionInfo cudnn_version =
@@ -2146,7 +2148,7 @@ index cc01f915d0..109df9d880 100644
// kernels makes sense.

diff --git a/xla/service/gpu/fusions/reduction.cc b/xla/service/gpu/fusions/reduction.cc
- index 881b17a52c..b5146a6a85 100644
+ index 881b17a52c..f08d077cc4 100644
--- a/xla/service/gpu/fusions/reduction.cc
+++ b/xla/service/gpu/fusions/reduction.cc
@@ -1153,7 +1153,11 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) {
@@ -2161,13 +2163,6 @@ index 881b17a52c..b5146a6a85 100644
if (reduction_dimensions.is_row_reduction &&
num_threads_x * 2 <= kThreadsPerBlockTarget) {
int64_t kept_size =
- @@ -1324,4 +1328,4 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToInputIndexing(
- }
-
- } // namespace gpu
- - } // namespace xla
- + } // namespace xla
- \ No newline at end of file
diff --git a/xla/service/gpu/fusions/reduction_base.cc b/xla/service/gpu/fusions/reduction_base.cc
index 66b095a617..e680411f6e 100644
--- a/xla/service/gpu/fusions/reduction_base.cc
@@ -2179,17 +2174,6 @@ index 66b095a617..e680411f6e 100644
- } // namespace xla
+ } // namespace xla
\ No newline at end of file
- diff --git a/xla/service/gpu/fusions/reduction_base.h b/xla/service/gpu/fusions/reduction_base.h
- index 7bf4437dea..8d0487fdd9 100644
- --- a/xla/service/gpu/fusions/reduction_base.h
- +++ b/xla/service/gpu/fusions/reduction_base.h
- @@ -53,4 +53,4 @@ void AddGroupIdConstraint(IndexingMap& map, int64_t root_index,
- } // namespace gpu
- } // namespace xla
-
- - #endif // XLA_SERVICE_GPU_FUSIONS_REDUCTION_BASE_H_
- + #endif // XLA_SERVICE_GPU_FUSIONS_REDUCTION_BASE_H_
- \ No newline at end of file
diff --git a/xla/service/gpu/fusions/reduction_mlir.cc b/xla/service/gpu/fusions/reduction_mlir.cc
index b3ec1f8e00..fefe200e34 100644
--- a/xla/service/gpu/fusions/reduction_mlir.cc
@@ -2957,7 +2941,7 @@ index 069ae1cf75..8440a33b4f 100644
rhs_shape.element_type() == S8);

diff --git a/xla/service/gpu/ir_emitter_context.cc b/xla/service/gpu/ir_emitter_context.cc
- index 0a51d97622..040c9c8cdd 100644
+ index 0a51d97622..03672a9fc8 100644
--- a/xla/service/gpu/ir_emitter_context.cc
+++ b/xla/service/gpu/ir_emitter_context.cc
@@ -67,6 +67,8 @@ void IrEmitterContext::emit_constant(int64_t num_elements,
@@ -2980,11 +2964,10 @@ index 0a51d97622..040c9c8cdd 100644
// These globals will be looked up by name by GpuExecutable so we need to
// give them an external linkage. Not all of their uses are visible in
// the LLVM IR so we can't give then a linkage that merely preserves their
- @@ -95,7 +97,28 @@ void IrEmitterContext::emit_constant(int64_t num_elements,
+ @@ -95,6 +97,27 @@ void IrEmitterContext::emit_constant(int64_t num_elements,
/*AddressSpace=*/addrspace,
/*isExternallyInitialized=*/false);
global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
- - llvm_module_constants()->insertGlobalVariable(global_for_const);
+
+ if (is_spir) {
+ // SYCL: Add spirv.Decorations for global variable. See document about the
@@ -3006,10 +2989,9 @@ index 0a51d97622..040c9c8cdd 100644
+ llvm::MDNode* md_list = llvm::MDNode::get(context, metadatas);
+ global_for_const->setMetadata("spirv.Decorations", md_list);
+ }
- + llvm_module_->insertGlobalVariable(global_for_const);
+ llvm_module_constants()->insertGlobalVariable(global_for_const);

info.symbol_name.assign(symbol_name);
- info.allocation_index = allocation_idx;
diff --git a/xla/service/gpu/ir_emitter_context.h b/xla/service/gpu/ir_emitter_context.h
index b1a7760c6a..5fd29e2cc6 100644
--- a/xla/service/gpu/ir_emitter_context.h
@@ -4006,7 +3988,7 @@ index 22d7f17813..5d64dfd606 100644
namespace gpu {

diff --git a/xla/service/gpu/model/gpu_performance_model_base.cc b/xla/service/gpu/model/gpu_performance_model_base.cc
- index 2d0fc0ab3d..c7b888e3e6 100644
+ index 2d0fc0ab3d..5940bcf7ce 100644
--- a/xla/service/gpu/model/gpu_performance_model_base.cc
+++ b/xla/service/gpu/model/gpu_performance_model_base.cc
@@ -77,8 +77,12 @@ int GetCoalescingWasteFactor(PrimitiveType element_type,
@@ -4022,16 +4004,14 @@ index 2d0fc0ab3d..c7b888e3e6 100644
float max_bandwidth = num_blocks * per_block_bandwidth;

return std::min(bandwidth, max_bandwidth);
- @@ -175,7 +179,12 @@ LaunchDimensions GpuPerformanceModelBase::EstimateFusionLaunchDimensions(
- // threads per block. In multi-output fusions, only look at one root.
+ @@ -176,6 +180,11 @@ LaunchDimensions GpuPerformanceModelBase::EstimateFusionLaunchDimensions(
VLOG(5) << "Using fallback launch dimensions estimate for "
<< fusion_analysis.fusion().ToString();
- - int64_t num_threads_per_block = 128;
+ int64_t num_threads_per_block = 128;
+ #if TENSORFLOW_USE_SYCL
+ const se::DeviceDescription device_info = fusion_analysis.device_info();
- + int64_t num_threads_per_block = RoundUpTo(device_info.threads_per_block_limit(),int64_t{32});
- + #else
- + int64_t num_threads_per_block = 128; // Result for default LaunchDimensionsConfig.
+ + num_threads_per_block =
+ + RoundUpTo(device_info.threads_per_block_limit(), int64_t{32});
+ #endif
int64_t estimated_num_threads =
ShapeUtil::ElementsInRecursive(fusion_analysis.fusion_root(0).shape());
@@ -5496,18 +5476,6 @@ index bd01a076cc..a99d49e837 100644
}

} // namespace gpu
- diff --git a/xla/service/instruction_fusion.cc b/xla/service/instruction_fusion.cc
- index 0439364a06..f83383564f 100644
- --- a/xla/service/instruction_fusion.cc
- +++ b/xla/service/instruction_fusion.cc
- @@ -938,7 +938,6 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer,
- };
- // A common special case is a slice or dynamic-slice and a
- // dynamic-update-slice that use the same indices. This pattern is safe.
- -
- auto is_nonelementwise_op = [](const HloInstruction* inst) {
- return inst->opcode() != HloOpcode::kFusion && !inst->IsElementwise() &&
- inst->opcode() != HloOpcode::kBitcast &&
diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc
index 1992f0dea0..e72edb2840 100644
--- a/xla/service/llvm_ir/llvm_util.cc