@@ -3232,7 +3232,7 @@ index c888aff30..0fd7254d9 100644
3232
3232
// Extract the memory value returned from atomicCAS and store it as
3233
3233
// cas_old_output.
3234
3234
diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
3235
- index 86426953a..620f092c7 100644
3235
+ index 86426953a..afd144957 100644
3236
3236
--- a/xla/service/gpu/ir_emitter_unnested.cc
3237
3237
+++ b/xla/service/gpu/ir_emitter_unnested.cc
3238
3238
@@ -1,4 +1,6 @@
@@ -3852,7 +3852,7 @@ index 86426953a..620f092c7 100644
3852
3852
static absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
3853
3853
mlir::DictionaryAttr dict) {
3854
3854
CustomCallThunk::AttributesMap attributes;
3855
- @@ -1314,6 +1337,103 @@ static absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
3855
+ @@ -1314,6 +1337,106 @@ static absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
3856
3856
}
3857
3857
return attributes;
3858
3858
}
@@ -3862,9 +3862,17 @@ index 86426953a..620f092c7 100644
3862
3862
+ // After 0.4.26, ffi support absl::span.
3863
3863
+ // Below attrs can be refine to absl::span for reducing key-value
3864
3864
+ CustomCallThunk::AttributesMap attrs;
3865
- + attrs["backend_config_str"] = instr->raw_backend_config_string();
3866
3865
+ if (IsCustomCallToDnnConvolution(*instr)) {
3866
+ + TF_ASSIGN_OR_RETURN(auto gpu_config,
3867
+ + instr->backend_config<GpuBackendConfig>());
3868
+ + const CudnnConvBackendConfig& backend_config =
3869
+ + gpu_config.cudnn_conv_backend_config();
3867
3870
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
3871
+ + attrs["conv_result_scale"] = static_cast<float>(backend_config.conv_result_scale());
3872
+ + attrs["side_input_scale"] = static_cast<float>(backend_config.side_input_scale());
3873
+ + attrs["activation_mode"] = static_cast<int32_t>(backend_config.activation_mode());
3874
+ + attrs["leakyrelu_alpha"] = static_cast<float>(backend_config.leakyrelu_alpha());
3875
+ +
3868
3876
+ const Window& window = instr->window();
3869
3877
+ const ConvolutionDimensionNumbers& dnums =
3870
3878
+ instr->convolution_dimension_numbers();
@@ -3930,23 +3938,18 @@ index 86426953a..620f092c7 100644
3930
3938
+ attrs["filter_dl"] = static_cast<int32_t>(filter_dl);
3931
3939
+ attrs["output_dl"] = static_cast<int32_t>(output_dl);
3932
3940
+ } else if (IsLegacyCublasMatmul(*instr) || IsCublasLtMatmul(*instr)) {
3933
- + const Shape& lhs_shape = instr->operand(0)->shape();
3934
- + const Shape& rhs_shape = instr->operand(1)->shape();
3935
- + const Shape& output_shape = instr->shape().IsTuple()
3936
- + ? instr->shape().tuple_shapes(0)
3937
- + : instr->shape();
3938
- + for (int i = 0; i < lhs_shape.layout().minor_to_major().size(); ++i) {
3939
- + attrs["lhs_minor_to_major_" + std::to_string(i)] =
3940
- + lhs_shape.layout().minor_to_major()[i];
3941
- + }
3942
- + for (int i = 0; i < rhs_shape.layout().minor_to_major().size(); ++i) {
3943
- + attrs["rhs_minor_to_major_" + std::to_string(i)] =
3944
- + rhs_shape.layout().minor_to_major()[i];
3945
- + }
3946
- + for (int i = 0; i < output_shape.layout().minor_to_major().size(); ++i) {
3947
- + attrs["output_minor_to_major_" + std::to_string(i)] =
3948
- + output_shape.layout().minor_to_major()[i];
3949
- + }
3941
+ + TF_ASSIGN_OR_RETURN(const auto gpu_config,
3942
+ + instr->backend_config<xla::gpu::GpuBackendConfig>());
3943
+ + xla::gpu::GemmBackendConfig config = gpu_config.gemm_backend_config();
3944
+ + xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
3945
+ + TF_ASSIGN_OR_RETURN(
3946
+ + auto gemm_config,
3947
+ + GemmConfig::For(static_cast<const HloInstruction*>(instr)));
3948
+ + GemmConfig* gemm_config_ptr = new GemmConfig(gemm_config);
3949
+ + attrs["epilogue"] = static_cast<int32_t>(epilogue);
3950
+ + // SYCL TODO:
3951
+ + // gemm_config may be split into separate parameters and added to attrs later.
3952
+ + attrs["gemm_config_ptr"] = reinterpret_cast<int64_t>(gemm_config_ptr);
3950
3953
+ } else {
3951
3954
+ return absl::InternalError("Unknown CustomCall To SYCL FFI Call");
3952
3955
+ }
@@ -3956,15 +3959,15 @@ index 86426953a..620f092c7 100644
3956
3959
3957
3960
absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3958
3961
const HloCustomCallInstruction* instr) {
3959
- @@ -1433,6 +1553 ,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3962
+ @@ -1433,6 +1556 ,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3960
3963
}
3961
3964
3962
3965
auto& backend_config_str = instr->raw_backend_config_string();
3963
3966
+
3964
3967
switch (instr->api_version()) {
3965
3968
case CustomCallApiVersion::API_VERSION_ORIGINAL:
3966
3969
case CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
3967
- @@ -1443,6 +1564 ,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3970
+ @@ -1443,6 +1567 ,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3968
3971
break;
3969
3972
3970
3973
case CustomCallApiVersion::API_VERSION_TYPED_FFI:
@@ -3977,7 +3980,7 @@ index 86426953a..620f092c7 100644
3977
3980
if (!backend_config_str.empty()) {
3978
3981
mlir::Attribute attr = mlir::parseAttribute(
3979
3982
backend_config_str, ir_emitter_context_->mlir_context());
3980
- @@ -1455,7 +1582 ,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3983
+ @@ -1455,7 +1585 ,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3981
3984
"dictionary attribute");
3982
3985
}
3983
3986
break;
@@ -3986,7 +3989,7 @@ index 86426953a..620f092c7 100644
3986
3989
default:
3987
3990
return Internal("Unknown custom-call API version enum value: %d",
3988
3991
instr->api_version());
3989
- @@ -1496,7 +1623 ,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
3992
+ @@ -1496,7 +1626 ,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
3990
3993
return absl::OkStatus();
3991
3994
}
3992
3995
@@ -3995,7 +3998,7 @@ index 86426953a..620f092c7 100644
3995
3998
3996
3999
absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
3997
4000
const HloInstruction* instr) {
3998
- @@ -1576,7 +1703 ,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
4001
+ @@ -1576,7 +1706 ,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
3999
4002
}
4000
4003
return absl::OkStatus();
4001
4004
}
@@ -4004,7 +4007,7 @@ index 86426953a..620f092c7 100644
4004
4007
4005
4008
absl::Status IrEmitterUnnested::EmitTopKCustomCall(
4006
4009
const HloCustomCallInstruction* instr) {
4007
- @@ -2602,33 +2729 ,33 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
4010
+ @@ -2602,33 +2732 ,33 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
4008
4011
}
4009
4012
4010
4013
absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
@@ -4064,7 +4067,7 @@ index 86426953a..620f092c7 100644
4064
4067
4065
4068
return absl::OkStatus();
4066
4069
}
4067
- @@ -2650,33 +2777 ,33 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
4070
+ @@ -2650,33 +2780 ,33 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
4068
4071
}
4069
4072
4070
4073
absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
@@ -4124,7 +4127,7 @@ index 86426953a..620f092c7 100644
4124
4127
4125
4128
return absl::OkStatus();
4126
4129
}
4127
- @@ -2798,13 +2925 ,31 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
4130
+ @@ -2798,13 +2928 ,31 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
4128
4131
case HloOpcode::kCustomCall: {
4129
4132
auto* custom_call = Cast<HloCustomCallInstruction>(instr);
4130
4133
if (IsLegacyCublasMatmul(*instr)) {
@@ -4158,7 +4161,7 @@ index 86426953a..620f092c7 100644
4158
4161
#if GOOGLE_CUDA
4159
4162
if (IsCublasLtMatmulF8(*instr)) {
4160
4163
return EmitCublasLtMatmulThunkF8(custom_call);
4161
- @@ -2815,30 +2960 ,32 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
4164
+ @@ -2815,30 +2963 ,32 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
4162
4165
if (IsCustomCallToDnnNorm(*instr)) {
4163
4166
return EmitNormThunk(custom_call);
4164
4167
}
@@ -4753,10 +4756,18 @@ index 15e6e6692..db8ebb0e1 100644
4753
4756
} // namespace xla::gpu
4754
4757
4755
4758
diff --git a/xla/service/gpu/runtime/BUILD b/xla/service/gpu/runtime/BUILD
4756
- index 4f298bbe0..39b81b2c7 100644
4759
+ index 4f298bbe0..1a45c9a12 100644
4757
4760
--- a/xla/service/gpu/runtime/BUILD
4758
4761
+++ b/xla/service/gpu/runtime/BUILD
4759
- @@ -427,6 +427,7 @@ cc_library(
4762
+ @@ -381,6 +381,7 @@ cc_library(
4763
+ "//xla/service:custom_call_status",
4764
+ "//xla/service:custom_call_status_internal",
4765
+ "//xla/service:executable",
4766
+ + "//xla/service/gpu:matmul_utils",
4767
+ "//xla/service/gpu:thunk",
4768
+ "//xla/stream_executor:device_memory",
4769
+ "//xla/stream_executor/gpu:gpu_stream_header",
4770
+ @@ -427,6 +428,7 @@ cc_library(
4760
4771
"//xla/service/gpu:thunk",
4761
4772
"//xla/stream_executor",
4762
4773
"@com_google_absl//absl/container:flat_hash_map",
@@ -4922,10 +4933,17 @@ index 02aecd464..df9213bae 100644
4922
4933
absl::flat_hash_map<const stream_executor::Stream*,
4923
4934
std::unique_ptr<GenericConvRunner>>
4924
4935
diff --git a/xla/service/gpu/runtime/custom_call_thunk.cc b/xla/service/gpu/runtime/custom_call_thunk.cc
4925
- index 28a7dcebf..4c8727f6b 100644
4936
+ index 28a7dcebf..faaa07689 100644
4926
4937
--- a/xla/service/gpu/runtime/custom_call_thunk.cc
4927
4938
+++ b/xla/service/gpu/runtime/custom_call_thunk.cc
4928
- @@ -36,7 +36,7 @@ limitations under the License.
4939
+ @@ -30,13 +30,14 @@ limitations under the License.
4940
+ #include "xla/service/buffer_assignment.h"
4941
+ #include "xla/service/custom_call_status.h"
4942
+ #include "xla/service/custom_call_status_internal.h"
4943
+ + #include "xla/service/gpu/matmul_utils.h"
4944
+ #include "xla/service/gpu/thunk.h"
4945
+ #include "xla/service/service_executable_run_options.h"
4946
+ #include "xla/status.h"
4929
4947
#include "xla/stream_executor/device_memory.h"
4930
4948
#include "xla/util.h"
4931
4949
@@ -4934,7 +4952,26 @@ index 28a7dcebf..4c8727f6b 100644
4934
4952
#include "xla/stream_executor/gpu/gpu_stream.h"
4935
4953
#endif
4936
4954
4937
- @@ -89,7 +89,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4955
+ @@ -70,6 +71,18 @@ CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler,
4956
+ attributes_(std::move(attributes)),
4957
+ called_computation_(called_computation) {}
4958
+
4959
+ + #ifdef TENSORFLOW_USE_SYCL
4960
+ + CustomCallThunk::~CustomCallThunk(){
4961
+ + if(attributes_.find("gemm_config_ptr") != attributes_.end()){
4962
+ + GemmConfig* gemm_config_ptr =
4963
+ + reinterpret_cast<GemmConfig*>(std::get<int64_t>(attributes_["gemm_config_ptr"]));
4964
+ + if(gemm_config_ptr != nullptr){
4965
+ + delete gemm_config_ptr;
4966
+ + }
4967
+ + }
4968
+ + }
4969
+ + #endif
4970
+ +
4971
+ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4972
+ // gpu_stream is CUstream or e.g. the equivalent type in ROCm.
4973
+ std::vector<void*> buffers;
4974
+ @@ -89,7 +102,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4938
4975
}
4939
4976
}
4940
4977
@@ -4943,7 +4980,7 @@ index 28a7dcebf..4c8727f6b 100644
4943
4980
auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream);
4944
4981
XlaCustomCallStatus custom_call_status;
4945
4982
call_target_(gpu_stream, buffers.data(), opaque_.data(), opaque_.size(),
4946
- @@ -100,11 +100 ,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4983
+ @@ -100,11 +113 ,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4947
4984
} else {
4948
4985
return absl::OkStatus();
4949
4986
}
@@ -4957,7 +4994,7 @@ index 28a7dcebf..4c8727f6b 100644
4957
4994
}
4958
4995
4959
4996
absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
4960
- @@ -139,6 +139 ,10 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
4997
+ @@ -139,6 +152 ,10 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
4961
4998
// execution context, as apparently it's not easily accessible from Thunk.
4962
4999
ExecutableRunOptions run_options;
4963
5000
run_options.set_stream(params.stream);
@@ -4969,7 +5006,7 @@ index 28a7dcebf..4c8727f6b 100644
4969
5006
4970
5007
CallOptions options = {&service_run_options, called_computation_};
4971
5008
diff --git a/xla/service/gpu/runtime/custom_call_thunk.h b/xla/service/gpu/runtime/custom_call_thunk.h
4972
- index 5fa1dce32..e75b61636 100644
5009
+ index 5fa1dce32..03a61bd0c 100644
4973
5010
--- a/xla/service/gpu/runtime/custom_call_thunk.h
4974
5011
+++ b/xla/service/gpu/runtime/custom_call_thunk.h
4975
5012
@@ -35,7 +35,7 @@ limitations under the License.
@@ -4996,6 +5033,17 @@ index 5fa1dce32..e75b61636 100644
4996
5033
4997
5034
using CustomCallTarget = std::function<void(Stream, void**, const char*,
4998
5035
size_t, XlaCustomCallStatus*)>;
5036
+ @@ -94,6 +94,10 @@ class CustomCallThunk : public Thunk {
5037
+ const std::vector<std::optional<Slice>>& results() const { return results_; }
5038
+ absl::string_view opaque() const { return opaque_; }
5039
+
5040
+ + #ifdef TENSORFLOW_USE_SYCL
5041
+ + ~CustomCallThunk();
5042
+ + #endif
5043
+ +
5044
+ private:
5045
+ absl::Status ExecuteCustomCall(const ExecuteParams& params);
5046
+ absl::Status ExecuteFfiHandler(const ExecuteParams& params);
4999
5047
diff --git a/xla/service/gpu/runtime/fft_thunk.cc b/xla/service/gpu/runtime/fft_thunk.cc
5000
5048
index 728c36752..fccde5793 100644
5001
5049
--- a/xla/service/gpu/runtime/fft_thunk.cc
0 commit comments