1
+ #include < sstream>
1
2
#include < string>
2
3
3
4
#include " exla_client.h"
4
5
#include " exla_cuda.h"
5
6
#include " exla_log_sink.h"
6
7
#include " exla_mlir.h"
7
8
#include " exla_nif_util.h"
9
+ #include " ipc.h"
8
10
#include " mhlo/IR/hlo_ops.h"
9
11
#include " mlir/Dialect/Func/IR/FuncOps.h"
10
12
#include " stablehlo/dialect/ChloOps.h"
@@ -449,34 +451,60 @@ ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_T
449
451
return exla::nif::error (env, " Unable to get device pointer kind." );
450
452
}
451
453
454
+ EXLA_ASSIGN_OR_RETURN_NIF (unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes (), env);
455
+
452
456
EXLA_ASSIGN_OR_RETURN_NIF (std::uintptr_t ptr,
453
457
(*buffer)->GetDevicePointer ((*client)->client ()), env);
454
458
455
- std::vector< unsigned char > pointer_vec ;
459
+ ERL_NIF_TERM out_term ;
456
460
if (pointer_kind == " local" ) {
457
- unsigned char * bytePtr = reinterpret_cast <unsigned char *>(&ptr);
458
- for (size_t i = 0 ; i < sizeof (void *); i++) {
459
- pointer_vec.push_back (bytePtr[i]);
461
+ ERL_NIF_TERM ptr_term = enif_make_ulong (env, ptr);
462
+ ERL_NIF_TERM size_term = enif_make_ulong (env, device_size);
463
+ out_term = enif_make_tuple2 (env, ptr_term, size_term);
464
+ } else if (pointer_kind == " host_ipc" ) {
465
+ std::ostringstream handle_name_stream;
466
+ handle_name_stream << " exla:ipc:" << device_size << " :" << ptr;
467
+ std::string handle_name = handle_name_stream.str ();
468
+ int fd = get_ipc_handle ((char *)handle_name.c_str (), device_size);
469
+
470
+ if (fd == -1 ) {
471
+ return exla::nif::error (env, " Unable to get IPC handle" );
472
+ }
473
+
474
+ void * ipc_ptr = open_ipc_handle (fd, device_size);
475
+ if (ipc_ptr == nullptr ) {
476
+ return exla::nif::error (env, " Unable to open IPC handle" );
460
477
}
478
+
479
+ memcpy (ipc_ptr, (void *)ptr, device_size);
480
+
481
+ ErlNifBinary handle_name_bin;
482
+ enif_alloc_binary (handle_name.size (), &handle_name_bin);
483
+ for (int i = 0 ; i < handle_name.size (); i++) {
484
+ handle_name_bin.data [i] = handle_name[i];
485
+ }
486
+ ERL_NIF_TERM handle_name_term = enif_make_binary (env, &handle_name_bin);
487
+ ERL_NIF_TERM size_term = enif_make_uint64 (env, device_size);
488
+ ERL_NIF_TERM fd_term = enif_make_int (env, fd);
489
+ out_term = enif_make_tuple3 (env, handle_name_term, fd_term, size_term);
461
490
} else if (pointer_kind == " cuda_ipc" ) {
462
491
auto result = get_cuda_ipc_handle (ptr);
463
492
if (result.second ) {
464
493
return exla::nif::error (env, " Unable to get cuda IPC handle" );
465
494
}
466
- pointer_vec = result.first ;
467
- }
495
+ auto pointer_vec = result.first ;
468
496
469
- EXLA_ASSIGN_OR_RETURN_NIF (unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes (), env);
470
-
471
- ERL_NIF_TERM handle_list[pointer_vec.size ()];
472
- for (int i = 0 ; i < pointer_vec.size (); i++) {
473
- handle_list[i] = enif_make_uint (env, pointer_vec[i]);
497
+ ErlNifBinary handle_bin;
498
+ enif_alloc_binary (pointer_vec.size (), &handle_bin);
499
+ for (int i = 0 ; i < pointer_vec.size (); i++) {
500
+ handle_bin.data [i] = pointer_vec[i];
501
+ }
502
+ ERL_NIF_TERM handle_term = enif_make_binary (env, &handle_bin);
503
+ ERL_NIF_TERM size_term = enif_make_uint64 (env, device_size);
504
+ out_term = enif_make_tuple2 (env, handle_term, size_term);
474
505
}
475
506
476
- ERL_NIF_TERM handle_list_term = enif_make_list_from_array (env, handle_list, pointer_vec.size ());
477
- ERL_NIF_TERM device_size_term = enif_make_uint64 (env, device_size);
478
-
479
- return exla::nif::ok (env, enif_make_tuple2 (env, handle_list_term, device_size_term));
507
+ return exla::nif::ok (env, out_term);
480
508
}
481
509
482
510
ERL_NIF_TERM create_buffer_from_device_pointer (ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
@@ -485,40 +513,68 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
485
513
}
486
514
487
515
exla::ExlaClient** client;
488
- std::vector<int64_t > pointer_vec;
516
+ ErlNifBinary cuda_ipc_handle_bin;
517
+ int cuda_ipc_handle_size = 0 ;
489
518
xla::Shape shape;
490
519
int device_id;
491
520
std::string pointer_kind;
521
+ void * ptr;
522
+ int fd = -1 ;
523
+ std::string memname;
492
524
493
525
if (!exla::nif::get<exla::ExlaClient*>(env, argv[0 ], client)) {
494
526
return exla::nif::error (env, " Unable to get client." );
495
527
}
496
- if (!exla::nif::get_list (env, argv[1 ], pointer_vec)) {
497
- return exla::nif::error (env, " Unable to get device pointer." );
498
- }
499
- if (!exla::nif::get_atom (env, argv[2 ], pointer_kind)) {
528
+ if (!exla::nif::get_atom (env, argv[1 ], pointer_kind)) {
500
529
return exla::nif::error (env, " Unable to get device pointer kind." );
501
530
}
531
+
532
+ if (pointer_kind == " cuda_ipc" ) {
533
+ if (!enif_inspect_binary (env, argv[2 ], &cuda_ipc_handle_bin)) {
534
+ return exla::nif::error (env, " Unable to get CUDA IPC handle." );
535
+ }
536
+ } else if (pointer_kind == " host_ipc" ) {
537
+ const ERL_NIF_TERM* tuple;
538
+ int arity;
539
+ if (
540
+ !enif_get_tuple (env, argv[2 ], &arity, &tuple) ||
541
+ (arity != 2 ) ||
542
+ !exla::nif::get (env, tuple[0 ], &fd) ||
543
+ (fd == -1 ) ||
544
+ !exla::nif::get (env, tuple[1 ], memname)) {
545
+ return exla::nif::error (env, " Unable to get IPC handle." );
546
+ }
547
+ } else if (pointer_kind == " local" ) {
548
+ int64_t ptr_int;
549
+ if (!exla::nif::get (env, argv[2 ], &ptr_int)) {
550
+ return exla::nif::error (env, " Unable to get pointer." );
551
+ }
552
+
553
+ ptr = (void *)ptr_int;
554
+ }
555
+
502
556
if (!exla::nif::get_typespec_as_xla_shape (env, argv[3 ], &shape)) {
503
557
return exla::nif::error (env, " Unable to get shape." );
504
558
}
505
559
if (!exla::nif::get (env, argv[4 ], &device_id)) {
506
560
return exla::nif::error (env, " Unable to get device ordinal." );
507
561
}
508
562
509
- void * ptr;
510
- if (pointer_kind == " local" ) {
511
- if (pointer_vec.size () != sizeof (void *)) {
512
- // This helps prevent segfaults if someone passes an IPC handle instead of
513
- // a local pointer.
514
- return exla::nif::error (env, " Invalid pointer size for selected mode." );
515
- }
516
- unsigned char * bytePtr = reinterpret_cast <unsigned char *>(&ptr);
517
- for (size_t i = 0 ; i < sizeof (void *); i++) {
518
- bytePtr[i] = pointer_vec[i];
563
+ std::function<void ()> on_delete_callback = []() {};
564
+
565
+ if (pointer_kind == " host_ipc" ) {
566
+ size_t device_size = (size_t )xla::ShapeUtil::ByteSizeOf (shape);
567
+
568
+ ptr = open_ipc_handle (fd, device_size);
569
+ if (ptr == nullptr ) {
570
+ return exla::nif::error (env, " Unable to get pointer for IPC handle." );
519
571
}
572
+
573
+ on_delete_callback = [fd, memname, ptr, device_size]() {
574
+ close_ipc_handle (fd, ptr, (char *)memname.c_str (), device_size);
575
+ };
520
576
} else if (pointer_kind == " cuda_ipc" ) {
521
- auto result = get_pointer_for_ipc_handle (pointer_vec , device_id);
577
+ auto result = get_pointer_for_ipc_handle (cuda_ipc_handle_bin. data , cuda_ipc_handle_bin. size , device_id);
522
578
if (result.second ) {
523
579
return exla::nif::error (env, " Unable to get pointer for IPC handle." );
524
580
}
@@ -527,8 +583,8 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
527
583
528
584
EXLA_ASSIGN_OR_RETURN_NIF (xla::PjRtDevice * device, (*client)->client ()->LookupDevice (xla::PjRtGlobalDeviceId (device_id)), env);
529
585
530
- std::function<void ()> on_delete_callback = []() {};
531
586
EXLA_ASSIGN_OR_RETURN_NIF (std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client ()->CreateViewOfDeviceBuffer (ptr, shape, device, on_delete_callback), env);
587
+
532
588
exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer (std::move (buffer));
533
589
return exla::nif::ok (env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
534
590
}
0 commit comments