@@ -18,7 +18,7 @@ pub(crate) fn handle_gpu_code<'ll>(
1818 // The offload memory transfer type for each kernel
1919 let mut memtransfer_types = vec ! [ ] ;
2020 let mut region_ids = vec ! [ ] ;
21- let offload_entry_ty = add_tgt_offload_entry ( & cx) ;
21+ let offload_entry_ty = TgtOffloadEntry :: new_decl ( & cx) ;
2222 for num in 0 ..9 {
2323 let kernel = cx. get_function ( & format ! ( "kernel_{num}" ) ) ;
2424 if let Some ( kernel) = kernel {
@@ -52,7 +52,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
5252// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
5353// offloaded was defined.
5454fn generate_at_one < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Value {
55- // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
5655 let unknown_txt = ";unknown;unknown;0;0;;" ;
5756 let c_entry_name = CString :: new ( unknown_txt) . unwrap ( ) ;
5857 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
@@ -77,15 +76,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7776 at_one
7877}
7978
80- pub ( crate ) fn add_tgt_offload_entry < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
81- let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
82- let tptr = cx. type_ptr ( ) ;
83- let ti64 = cx. type_i64 ( ) ;
84- let ti32 = cx. type_i32 ( ) ;
85- let ti16 = cx. type_i16 ( ) ;
86- // For each kernel to run on the gpu, we will later generate one entry of this type.
87- // copied from LLVM
88- // typedef struct {
79+ struct TgtOffloadEntry {
8980 // uint64_t Reserved;
9081 // uint16_t Version;
9182 // uint16_t Kind;
@@ -95,21 +86,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
9586 // uint64_t Size; Size of the entry info (0 if it is a function)
9687 // uint64_t Data;
9788 // void *AuxAddr;
98- // } __tgt_offload_entry;
99- let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
100- cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
101- offload_entry_ty
10289}
10390
104- fn gen_tgt_kernel_global < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
105- let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
106- let tptr = cx. type_ptr ( ) ;
107- let ti64 = cx. type_i64 ( ) ;
108- let ti32 = cx. type_i32 ( ) ;
109- let tarr = cx. type_array ( ti32, 3 ) ;
91+ impl TgtOffloadEntry {
92+ pub ( crate ) fn new_decl < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
93+ let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
94+ let tptr = cx. type_ptr ( ) ;
95+ let ti64 = cx. type_i64 ( ) ;
96+ let ti32 = cx. type_i32 ( ) ;
97+ let ti16 = cx. type_i16 ( ) ;
98+ // For each kernel to run on the gpu, we will later generate one entry of this type.
99+ // copied from LLVM
100+ let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
101+ cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
102+ offload_entry_ty
103+ }
104+
105+ fn new < ' ll > (
106+ cx : & ' ll SimpleCx < ' _ > ,
107+ region_id : & ' ll Value ,
108+ llglobal : & ' ll Value ,
109+ ) -> [ & ' ll Value ; 9 ] {
110+ let reserved = cx. get_const_i64 ( 0 ) ;
111+ let version = cx. get_const_i16 ( 1 ) ;
112+ let kind = cx. get_const_i16 ( 1 ) ;
113+ let flags = cx. get_const_i32 ( 0 ) ;
114+ let size = cx. get_const_i64 ( 0 ) ;
115+ let data = cx. get_const_i64 ( 0 ) ;
116+ let aux_addr = cx. const_null ( cx. type_ptr ( ) ) ;
117+ [ reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
118+ }
119+ }
110120
111- // Taken from the LLVM APITypes.h declaration:
112- // struct KernelArgsTy {
121+ // Taken from the LLVM APITypes.h declaration:
122+ struct KernelArgsTy {
113123 // uint32_t Version = 0; // Version of this struct for ABI compatibility.
114124 // uint32_t NumArgs = 0; // Number of arguments in each input pointer.
115125 // void **ArgBasePtrs =
@@ -120,8 +130,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
120130 // void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
121131 // void **ArgMappers = nullptr; // User-defined mappers, possibly null.
122132 // uint64_t Tripcount =
123- // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
124- // struct {
133+ // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
134+ // struct {
125135 // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
126136 // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
127137 // uint64_t Unused : 62;
@@ -131,12 +141,54 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
131141 // // The number of threads (for x,y,z dimension).
132142 // uint32_t ThreadLimit[3] = {0, 0, 0};
133143 // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
134- //};
135- let kernel_elements =
136- vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
144+ }
145+
146+ impl KernelArgsTy {
147+ const OFFLOAD_VERSION : u64 = 3 ;
148+ const FLAGS : u64 = 0 ;
149+ const TRIPCOUNT : u64 = 0 ;
150+ fn new_decl < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll Type {
151+ let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
152+ let tptr = cx. type_ptr ( ) ;
153+ let ti64 = cx. type_i64 ( ) ;
154+ let ti32 = cx. type_i32 ( ) ;
155+ let tarr = cx. type_array ( ti32, 3 ) ;
156+
157+ let kernel_elements =
158+ vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
159+
160+ cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
161+ kernel_arguments_ty
162+ }
137163
138- cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
139- kernel_arguments_ty
164+ fn new < ' ll > (
165+ cx : & ' ll SimpleCx < ' _ > ,
166+ num_args : u64 ,
167+ memtransfer_types : & [ & ' ll Value ] ,
168+ geps : [ & ' ll Value ; 3 ] ,
169+ ) -> [ ( Align , & ' ll Value ) ; 13 ] {
170+ let four = Align :: from_bytes ( 4 ) . expect ( "4 Byte alignment should work" ) ;
171+ let eight = Align :: EIGHT ;
172+
173+ let ti32 = cx. type_i32 ( ) ;
174+ let ci32_0 = cx. get_const_i32 ( 0 ) ;
175+ [
176+ ( four, cx. get_const_i32 ( KernelArgsTy :: OFFLOAD_VERSION ) ) ,
177+ ( four, cx. get_const_i32 ( num_args) ) ,
178+ ( eight, geps[ 0 ] ) ,
179+ ( eight, geps[ 1 ] ) ,
180+ ( eight, geps[ 2 ] ) ,
181+ ( eight, memtransfer_types[ 0 ] ) ,
182+ // The next two are debug infos. FIXME(offload): set them
183+ ( eight, cx. const_null ( cx. type_ptr ( ) ) ) , // dbg
184+ ( eight, cx. const_null ( cx. type_ptr ( ) ) ) , // dbg
185+ ( eight, cx. get_const_i64 ( KernelArgsTy :: TRIPCOUNT ) ) ,
186+ ( eight, cx. get_const_i64 ( KernelArgsTy :: FLAGS ) ) ,
187+ ( four, cx. const_array ( ti32, & [ cx. get_const_i32 ( 2097152 ) , ci32_0, ci32_0] ) ) ,
188+ ( four, cx. const_array ( ti32, & [ cx. get_const_i32 ( 256 ) , ci32_0, ci32_0] ) ) ,
189+ ( four, cx. get_const_i32 ( 0 ) ) ,
190+ ]
191+ }
140192}
141193
142194fn gen_tgt_data_mappers < ' ll > (
@@ -245,19 +297,10 @@ fn gen_define_handling<'ll>(
245297 let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
246298 llvm:: set_alignment ( llglobal, Align :: ONE ) ;
247299 llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
248-
249- // Not actively used yet, for calling real kernels
250300 let name = format ! ( ".offloading.entry.kernel_{num}" ) ;
251301
252302 // See the __tgt_offload_entry documentation above.
253- let reserved = cx. get_const_i64 ( 0 ) ;
254- let version = cx. get_const_i16 ( 1 ) ;
255- let kind = cx. get_const_i16 ( 1 ) ;
256- let flags = cx. get_const_i32 ( 0 ) ;
257- let size = cx. get_const_i64 ( 0 ) ;
258- let data = cx. get_const_i64 ( 0 ) ;
259- let aux_addr = cx. const_null ( cx. type_ptr ( ) ) ;
260- let elems = vec ! [ reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr] ;
303+ let elems = TgtOffloadEntry :: new ( & cx, region_id, llglobal) ;
261304
262305 let initializer = crate :: common:: named_struct ( offload_entry_ty, & elems) ;
263306 let c_name = CString :: new ( name) . unwrap ( ) ;
@@ -319,7 +362,7 @@ fn gen_call_handling<'ll>(
319362 let tgt_bin_desc = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
320363 cx. set_struct_body ( tgt_bin_desc, & tgt_bin_desc_ty, false ) ;
321364
322- let tgt_kernel_decl = gen_tgt_kernel_global ( & cx) ;
365+ let tgt_kernel_decl = KernelArgsTy :: new_decl ( & cx) ;
323366 let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
324367
325368 let main_fn = cx. get_function ( "main" ) ;
@@ -407,19 +450,19 @@ fn gen_call_handling<'ll>(
407450 a1 : & ' ll Value ,
408451 a2 : & ' ll Value ,
409452 a4 : & ' ll Value ,
410- ) -> ( & ' ll Value , & ' ll Value , & ' ll Value ) {
453+ ) -> [ & ' ll Value ; 3 ] {
411454 let i32_0 = cx. get_const_i32 ( 0 ) ;
412455
413456 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
414457 let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
415458 let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
416- ( gep1, gep2, gep3)
459+ [ gep1, gep2, gep3]
417460 }
418461
419462 fn generate_mapper_call < ' a , ' ll > (
420463 builder : & mut SBuilder < ' a , ' ll > ,
421464 cx : & ' ll SimpleCx < ' ll > ,
422- geps : ( & ' ll Value , & ' ll Value , & ' ll Value ) ,
465+ geps : [ & ' ll Value ; 3 ] ,
423466 o_type : & ' ll Value ,
424467 fn_to_call : & ' ll Value ,
425468 fn_ty : & ' ll Type ,
@@ -430,7 +473,7 @@ fn gen_call_handling<'ll>(
430473 let i64_max = cx. get_const_i64 ( u64:: MAX ) ;
431474 let num_args = cx. get_const_i32 ( num_args) ;
432475 let args =
433- vec ! [ s_ident_t, i64_max, num_args, geps. 0 , geps. 1 , geps. 2 , o_type, nullptr, nullptr] ;
476+ vec ! [ s_ident_t, i64_max, num_args, geps[ 0 ] , geps[ 1 ] , geps[ 2 ] , o_type, nullptr, nullptr] ;
434477 builder. call ( fn_ty, fn_to_call, & args, None ) ;
435478 }
436479
@@ -439,36 +482,20 @@ fn gen_call_handling<'ll>(
439482 let o = memtransfer_types[ 0 ] ;
440483 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
441484 generate_mapper_call ( & mut builder, & cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t) ;
485+ let values = KernelArgsTy :: new ( & cx, num_args, memtransfer_types, geps) ;
442486
443487 // Step 3)
444- let mut values = vec ! [ ] ;
445- let offload_version = cx. get_const_i32 ( 3 ) ;
446- values. push ( ( 4 , offload_version) ) ;
447- values. push ( ( 4 , cx. get_const_i32 ( num_args) ) ) ;
448- values. push ( ( 8 , geps. 0 ) ) ;
449- values. push ( ( 8 , geps. 1 ) ) ;
450- values. push ( ( 8 , geps. 2 ) ) ;
451- values. push ( ( 8 , memtransfer_types[ 0 ] ) ) ;
452- // The next two are debug infos. FIXME(offload) set them
453- values. push ( ( 8 , cx. const_null ( cx. type_ptr ( ) ) ) ) ;
454- values. push ( ( 8 , cx. const_null ( cx. type_ptr ( ) ) ) ) ;
455- values. push ( ( 8 , cx. get_const_i64 ( 0 ) ) ) ;
456- values. push ( ( 8 , cx. get_const_i64 ( 0 ) ) ) ;
457- let ti32 = cx. type_i32 ( ) ;
458- let ci32_0 = cx. get_const_i32 ( 0 ) ;
459- values. push ( ( 4 , cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 2097152 ) , ci32_0, ci32_0] ) ) ) ;
460- values. push ( ( 4 , cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 256 ) , ci32_0, ci32_0] ) ) ) ;
461- values. push ( ( 4 , cx. get_const_i32 ( 0 ) ) ) ;
462-
488+ // Here we fill the KernelArgsTy, see the documentation above
463489 for ( i, value) in values. iter ( ) . enumerate ( ) {
464490 let ptr = builder. inbounds_gep ( tgt_kernel_decl, a5, & [ i32_0, cx. get_const_i32 ( i as u64 ) ] ) ;
465- builder. store ( value. 1 , ptr, Align :: from_bytes ( value. 0 ) . unwrap ( ) ) ;
491+ builder. store ( value. 1 , ptr, value. 0 ) ;
466492 }
467493
468494 let args = vec ! [
469495 s_ident_t,
470- // MAX == -1
471- cx. get_const_i64( u64 :: MAX ) ,
496+ // FIXME(offload) give users a way to select which GPU to use.
497+ cx. get_const_i64( u64 :: MAX ) , // MAX == -1.
498+ // FIXME(offload): Don't hardcode the numbers of threads in the future.
472499 cx. get_const_i32( 2097152 ) ,
473500 cx. get_const_i32( 256 ) ,
474501 region_ids[ 0 ] ,
@@ -483,19 +510,14 @@ fn gen_call_handling<'ll>(
483510 }
484511
485512 // Step 4)
486- //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
487-
488513 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
489514 generate_mapper_call ( & mut builder, & cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t) ;
490515
491516 builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
492517
493518 drop ( builder) ;
519+ // FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
520+ // and then delete the call to the CPU version. In the future, we should use an intrinsic which
521+ // directly resolves to a call to the GPU version.
494522 unsafe { llvm:: LLVMDeleteFunction ( called) } ;
495-
496- // With this we generated the following begin and end mappers. We could easily generate the
497- // update mapper in an update.
498- // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
499- // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
500- // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
501523}
0 commit comments