@@ -3012,7 +3012,7 @@ index 7403531a8..cc3b8aadd 100644
3012
3012
// Extract the memory value returned from atomicCAS and store it as
3013
3013
// cas_old_output.
3014
3014
diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
3015
- index 4ba739dba..475cd1217 100644
3015
+ index 4ba739dba..a84be445e 100644
3016
3016
--- a/xla/service/gpu/ir_emitter_unnested.cc
3017
3017
+++ b/xla/service/gpu/ir_emitter_unnested.cc
3018
3018
@@ -1,4 +1,6 @@
@@ -3240,7 +3240,7 @@ index 4ba739dba..475cd1217 100644
3240
3240
absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
3241
3241
TF_ASSIGN_OR_RETURN(CholeskyOptions options,
3242
3242
instr->backend_config<CholeskyOptions>());
3243
- @@ -1301,7 +1333,108 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
3243
+ @@ -1301,7 +1333,144 @@ absl::Status IrEmitterUnnested::EmitCholeskyThunk(const HloInstruction* instr) {
3244
3244
3245
3245
return absl::OkStatus();
3246
3246
}
@@ -3250,19 +3250,21 @@ index 4ba739dba..475cd1217 100644
3250
3250
+ #ifdef TENSORFLOW_USE_SYCL
3251
3251
+ absl::StatusOr<CustomCallThunk::AttributesMap> BuildAttributesMap(
3252
3252
+ const HloCustomCallInstruction* instr) {
3253
- + // After 0.4.26, ffi support absl::span.
3254
- + // Below attrs can be refine to absl::span for reducing key-value
3255
3253
+ CustomCallThunk::AttributesMap attrs;
3256
3254
+ if (IsCustomCallToDnnConvolution(*instr)) {
3257
3255
+ TF_ASSIGN_OR_RETURN(auto gpu_config,
3258
3256
+ instr->backend_config<GpuBackendConfig>());
3259
3257
+ const CudnnConvBackendConfig& backend_config =
3260
3258
+ gpu_config.cudnn_conv_backend_config();
3261
3259
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
3262
- + attrs["conv_result_scale"] = static_cast<float>(backend_config.conv_result_scale());
3263
- + attrs["side_input_scale"] = static_cast<float>(backend_config.side_input_scale());
3264
- + attrs["activation_mode"] = static_cast<int32_t>(backend_config.activation_mode());
3265
- + attrs["leakyrelu_alpha"] = static_cast<float>(backend_config.leakyrelu_alpha());
3260
+ + attrs["conv_result_scale"] =
3261
+ + static_cast<float>(backend_config.conv_result_scale());
3262
+ + attrs["side_input_scale"] =
3263
+ + static_cast<float>(backend_config.side_input_scale());
3264
+ + attrs["activation_mode"] =
3265
+ + static_cast<int32_t>(backend_config.activation_mode());
3266
+ + attrs["leakyrelu_alpha"] =
3267
+ + static_cast<float>(backend_config.leakyrelu_alpha());
3266
3268
+
3267
3269
+ const Window& window = instr->window();
3268
3270
+ const ConvolutionDimensionNumbers& dnums =
@@ -3333,14 +3335,48 @@ index 4ba739dba..475cd1217 100644
3333
3335
+ instr->backend_config<xla::gpu::GpuBackendConfig>());
3334
3336
+ xla::gpu::GemmBackendConfig config = gpu_config.gemm_backend_config();
3335
3337
+ xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue();
3336
- + TF_ASSIGN_OR_RETURN(
3337
- + auto gemm_config,
3338
- + GemmConfig::For(static_cast<const HloInstruction*>(instr)));
3339
- + GemmConfig* gemm_config_ptr = new GemmConfig(gemm_config);
3340
3338
+ attrs["epilogue"] = static_cast<int32_t>(epilogue);
3341
- + // SYCL TODO:
3342
- + // gemm_config may be split into separate parameters and added to attrs later.
3343
- + attrs["gemm_config_ptr"] = reinterpret_cast<int64_t>(gemm_config_ptr);
3339
+ +
3340
+ + TF_ASSIGN_OR_RETURN(
3341
+ + auto gemm_config,
3342
+ + GemmConfig::For(static_cast<const HloInstruction*>(instr)));
3343
+ +
3344
+ + attrs["lhs_layout_dtype"] =
3345
+ + static_cast<int32_t>(gemm_config.lhs_layout.dtype);
3346
+ + attrs["lhs_order"] = static_cast<int32_t>(gemm_config.lhs_layout.order);
3347
+ + attrs["lhs_num_cols"] = gemm_config.lhs_layout.num_cols;
3348
+ + attrs["lhs_num_rows"] = gemm_config.lhs_layout.num_rows;
3349
+ + attrs["lhs_batch_stride"] = gemm_config.lhs_layout.batch_stride;
3350
+ + attrs["lhs_leading_dim_stride"] = gemm_config.lhs_layout.leading_dim_stride;
3351
+ +
3352
+ + attrs["rhs_layout_dtype"] =
3353
+ + static_cast<int32_t>(gemm_config.rhs_layout.dtype);
3354
+ + attrs["rhs_order"] = static_cast<int32_t>(gemm_config.rhs_layout.order);
3355
+ + attrs["rhs_num_cols"] = gemm_config.rhs_layout.num_cols;
3356
+ + attrs["rhs_num_rows"] = gemm_config.rhs_layout.num_rows;
3357
+ + attrs["rhs_batch_stride"] = gemm_config.rhs_layout.batch_stride;
3358
+ + attrs["rhs_leading_dim_stride"] = gemm_config.rhs_layout.leading_dim_stride;
3359
+ +
3360
+ + attrs["output_layout_dtype"] =
3361
+ + static_cast<int32_t>(gemm_config.output_layout.dtype);
3362
+ + attrs["output_order"] =
3363
+ + static_cast<int32_t>(gemm_config.output_layout.order);
3364
+ + attrs["output_num_cols"] = gemm_config.output_layout.num_cols;
3365
+ + attrs["output_num_rows"] = gemm_config.output_layout.num_rows;
3366
+ + attrs["output_batch_stride"] = gemm_config.output_layout.batch_stride;
3367
+ + attrs["output_leading_dim_stride"] =
3368
+ + gemm_config.output_layout.leading_dim_stride;
3369
+ +
3370
+ + attrs["batch_size"] =
3371
+ + static_cast<int64_t>(gemm_config.output_layout.batch_size);
3372
+ + attrs["alpha"] = static_cast<float>(gemm_config.alpha.real());
3373
+ + attrs["beta"] = static_cast<float>(gemm_config.beta);
3374
+ +
3375
+ + // config.algorithm is less than 0, thus 0 means no algorithm
3376
+ + if (gemm_config.algorithm.has_value()) {
3377
+ + attrs["algorithm"] = static_cast<int64_t>(gemm_config.algorithm.value());
3378
+ + } else
3379
+ + attrs["algorithm"] = static_cast<int64_t>(0);
3344
3380
+ } else {
3345
3381
+ return absl::InternalError("Unknown CustomCall To SYCL FFI Call");
3346
3382
+ }
@@ -3350,7 +3386,7 @@ index 4ba739dba..475cd1217 100644
3350
3386
3351
3387
absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3352
3388
const HloCustomCallInstruction* instr) {
3353
- @@ -1431,6 +1564 ,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3389
+ @@ -1431,6 +1600 ,12 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3354
3390
break;
3355
3391
3356
3392
case CustomCallApiVersion::API_VERSION_TYPED_FFI:
@@ -3363,7 +3399,7 @@ index 4ba739dba..475cd1217 100644
3363
3399
if (!backend_config_str.empty()) {
3364
3400
mlir::Attribute attr = mlir::parseAttribute(
3365
3401
backend_config_str, ir_emitter_context_->mlir_context());
3366
- @@ -1443,7 +1582 ,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3402
+ @@ -1443,7 +1618 ,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
3367
3403
"dictionary attribute");
3368
3404
}
3369
3405
break;
@@ -3372,7 +3408,7 @@ index 4ba739dba..475cd1217 100644
3372
3408
default:
3373
3409
return Internal("Unknown custom-call API version enum value: %d",
3374
3410
instr->api_version());
3375
- @@ -1484,7 +1623 ,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
3411
+ @@ -1484,7 +1659 ,7 @@ absl::Status IrEmitterUnnested::EmitFftThunk(const HloFftInstruction* instr) {
3376
3412
return absl::OkStatus();
3377
3413
}
3378
3414
@@ -3381,7 +3417,7 @@ index 4ba739dba..475cd1217 100644
3381
3417
3382
3418
absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
3383
3419
const HloInstruction* instr) {
3384
- @@ -1564,7 +1703 ,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
3420
+ @@ -1564,7 +1739 ,7 @@ absl::Status IrEmitterUnnested::EmitTriangularSolveCustomCall(
3385
3421
}
3386
3422
return absl::OkStatus();
3387
3423
}
@@ -3390,15 +3426,15 @@ index 4ba739dba..475cd1217 100644
3390
3426
3391
3427
absl::Status IrEmitterUnnested::EmitTopKCustomCall(
3392
3428
const HloCustomCallInstruction* instr) {
3393
- @@ -2617,6 +2756 ,7 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
3429
+ @@ -2617,6 +2792 ,7 @@ absl::Status IrEmitterUnnested::EmitCopyStartThunk(
3394
3430
}
3395
3431
3396
3432
absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
3397
3433
+ #if 0
3398
3434
if (!instr->channel_id().has_value())
3399
3435
return absl::InternalError("Unknown send instruction channel id");
3400
3436
3401
- @@ -2669,12 +2809 ,14 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
3437
+ @@ -2669,12 +2845 ,14 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) {
3402
3438
*instr->channel_id(), send_recv_events_,
3403
3439
ConvertFrontendAttributes(instr->frontend_attributes()),
3404
3440
DeviceConstraint(instr)));
@@ -3414,7 +3450,7 @@ index 4ba739dba..475cd1217 100644
3414
3450
if (!instr->channel_id().has_value())
3415
3451
return absl::InternalError("Unknown send done instruction channel id");
3416
3452
3417
- @@ -2685,11 +2827 ,13 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
3453
+ @@ -2685,11 +2863 ,13 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk(
3418
3454
AddThunkToThunkSequence(std::make_unique<SendDoneThunk>(
3419
3455
Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(),
3420
3456
send_recv_events_, DeviceConstraint(instr)));
@@ -3429,7 +3465,7 @@ index 4ba739dba..475cd1217 100644
3429
3465
if (!instr->channel_id().has_value())
3430
3466
return absl::InternalError("Unknown recv instruction channel id");
3431
3467
TF_RET_CHECK(instr->shape().IsTuple());
3432
- @@ -2744,11 +2888 ,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
3468
+ @@ -2744,11 +2924 ,13 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) {
3433
3469
ConvertFrontendAttributes(instr->frontend_attributes()),
3434
3470
DeviceConstraint(instr)));
3435
3471
@@ -3444,7 +3480,7 @@ index 4ba739dba..475cd1217 100644
3444
3480
if (!instr->channel_id().has_value())
3445
3481
return absl::InternalError("Unknown recv done instruction channel id");
3446
3482
3447
- @@ -2759,8 +2905 ,9 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(
3483
+ @@ -2759,8 +2941 ,9 @@ absl::Status IrEmitterUnnested::EmitRecvDoneThunk(
3448
3484
AddThunkToThunkSequence(std::make_unique<RecvDoneThunk>(
3449
3485
Thunk::ThunkInfo::WithProfileAnnotation(instr), *instr->channel_id(),
3450
3486
send_recv_events_, DeviceConstraint(instr)));
@@ -3455,7 +3491,7 @@ index 4ba739dba..475cd1217 100644
3455
3491
}
3456
3492
3457
3493
absl::Status IrEmitterUnnested::EmitHloInstruction(
3458
- @@ -2871,47 +3018 ,67 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
3494
+ @@ -2871,47 +3054 ,67 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
3459
3495
case HloOpcode::kCustomCall: {
3460
3496
auto* custom_call = Cast<HloCustomCallInstruction>(instr);
3461
3497
if (IsLegacyCublasMatmul(*instr)) {
@@ -4816,7 +4852,7 @@ index 3f9db4ea2..2aa8d4030 100644
4816
4852
absl::flat_hash_map<const stream_executor::Stream*,
4817
4853
std::unique_ptr<GenericConvRunner>>
4818
4854
diff --git a/xla/service/gpu/runtime/custom_call_thunk.cc b/xla/service/gpu/runtime/custom_call_thunk.cc
4819
- index 0eaf0aaf4..172ed916b 100644
4855
+ index 0eaf0aaf4..d9c4b733c 100644
4820
4856
--- a/xla/service/gpu/runtime/custom_call_thunk.cc
4821
4857
+++ b/xla/service/gpu/runtime/custom_call_thunk.cc
4822
4858
@@ -38,6 +38,7 @@ limitations under the License.
@@ -4836,28 +4872,7 @@ index 0eaf0aaf4..172ed916b 100644
4836
4872
#include "xla/stream_executor/gpu/gpu_stream.h"
4837
4873
#endif
4838
4874
4839
- @@ -79,6 +80,20 @@ CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler* handler,
4840
- attributes_(std::move(attributes)),
4841
- called_computation_(called_computation) {}
4842
-
4843
- + #ifdef TENSORFLOW_USE_SYCL
4844
- + CustomCallThunk::~CustomCallThunk(){
4845
- + if(attributes_.find("gemm_config_ptr") != attributes_.end()){
4846
- + GemmConfig* gemm_config_ptr =
4847
- + reinterpret_cast<GemmConfig*>(
4848
- + std::get<int64_t>(
4849
- + std::get<std::variant<int32_t, int64_t, float>>(attributes_["gemm_config_ptr"])));
4850
- + if(gemm_config_ptr != nullptr){
4851
- + delete gemm_config_ptr;
4852
- + }
4853
- + }
4854
- + }
4855
- + #endif
4856
- +
4857
- absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4858
- // gpu_stream is CUstream or e.g. the equivalent type in ROCm.
4859
- std::vector<void*> buffers;
4860
- @@ -98,7 +113,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4875
+ @@ -98,7 +99,7 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4861
4876
}
4862
4877
}
4863
4878
@@ -4866,7 +4881,7 @@ index 0eaf0aaf4..172ed916b 100644
4866
4881
auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream);
4867
4882
XlaCustomCallStatus custom_call_status;
4868
4883
call_target_(gpu_stream, buffers.data(), opaque_.data(), opaque_.size(),
4869
- @@ -109,11 +124 ,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4884
+ @@ -109,11 +110 ,11 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
4870
4885
} else {
4871
4886
return absl::OkStatus();
4872
4887
}
@@ -4881,7 +4896,7 @@ index 0eaf0aaf4..172ed916b 100644
4881
4896
4882
4897
absl::Status CustomCallThunk::ExecuteFfiHandler(const ExecuteParams& params) {
4883
4898
diff --git a/xla/service/gpu/runtime/custom_call_thunk.h b/xla/service/gpu/runtime/custom_call_thunk.h
4884
- index 02679d2e0..2bd5a73c6 100644
4899
+ index 02679d2e0..1bcb07264 100644
4885
4900
--- a/xla/service/gpu/runtime/custom_call_thunk.h
4886
4901
+++ b/xla/service/gpu/runtime/custom_call_thunk.h
4887
4902
@@ -32,7 +32,7 @@ limitations under the License.
@@ -4908,17 +4923,6 @@ index 02679d2e0..2bd5a73c6 100644
4908
4923
4909
4924
using CustomCallTarget = std::function<void(Stream, void**, const char*,
4910
4925
size_t, XlaCustomCallStatus*)>;
4911
- @@ -91,6 +91,10 @@ class CustomCallThunk : public Thunk {
4912
- const std::vector<std::optional<Slice>>& results() const { return results_; }
4913
- absl::string_view opaque() const { return opaque_; }
4914
-
4915
- + #ifdef TENSORFLOW_USE_SYCL
4916
- + ~CustomCallThunk();
4917
- + #endif
4918
- +
4919
- private:
4920
- absl::Status ExecuteCustomCall(const ExecuteParams& params);
4921
- absl::Status ExecuteFfiHandler(const ExecuteParams& params);
4922
4926
diff --git a/xla/service/gpu/runtime/fft_thunk.cc b/xla/service/gpu/runtime/fft_thunk.cc
4923
4927
index 728c36752..fccde5793 100644
4924
4928
--- a/xla/service/gpu/runtime/fft_thunk.cc
0 commit comments