@@ -154,6 +154,29 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD):
154154 return comm .bcast (compiler , 0 )
155155
156156
def _check_src_hashes(comm, global_kernel):
    """Check that all ranks of ``comm`` hashed the same kernel cache key.

    The check is only performed when ``configuration["check_src_hashes"]``
    or ``configuration["debug"]`` is set; otherwise this returns immediately
    without communicating.  When enabled, the call is collective over
    ``comm``.

    :arg comm: communicator over which the hashes are reduced.
    :arg global_kernel: object providing ``cache_key`` and
        ``code_to_compile``.
    :raises CompilationError: if the value reduced with ``_check_op`` differs
        from this rank's hash.  Before raising, every rank writes its source
        to ``<cache_dir>/mismatching-kernels/src-rank<N>.c``.
    """
    if not (configuration["check_src_hashes"] or configuration["debug"]):
        return

    hsh = md5(str(global_kernel.cache_key[1:]).encode())
    # The first two hex digits are dropped before comparing, as the cache
    # layout uses them as a directory name.
    basename = hsh.hexdigest()[2:]

    matching = comm.allreduce(basename, op=_check_op)
    if matching != basename:
        # Dump all src code to disk for debugging.  The dump directory must
        # not depend on the hash: the hashes differ between ranks here, so a
        # hash-derived directory would be a different path on each rank while
        # only rank 0 creates it.
        output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
        srcfile = os.path.join(output, "src-rank%d.c" % comm.rank)
        if comm.rank == 0:
            os.makedirs(output, exist_ok=True)
        # Wait until rank 0 has created the directory before writing into it.
        comm.barrier()
        with open(srcfile, "w") as f:
            f.write(global_kernel.code_to_compile)
        # Make sure every rank has finished writing before any of them raises.
        comm.barrier()
        raise CompilationError("Generated code differs across ranks"
                               f" (see output in {output})")
178+
179+
157180class Compiler (ABC ):
158181 """A compiler for shared libraries.
159182
@@ -324,19 +347,8 @@ def get_so(self, jitmodule, extension):
324347 # atomically (avoiding races).
325348 tmpname = os .path .join (cachedir , "%s_p%d.so.tmp" % (basename , pid ))
326349
327- if configuration ['check_src_hashes' ] or configuration ['debug' ]:
328- matching = self .comm .allreduce (basename , op = _check_op )
329- if matching != basename :
330- # Dump all src code to disk for debugging
331- output = os .path .join (configuration ["cache_dir" ], "mismatching-kernels" )
332- srcfile = os .path .join (output , "src-rank%d.c" % self .comm .rank )
333- if self .comm .rank == 0 :
334- os .makedirs (output , exist_ok = True )
335- self .comm .barrier ()
336- with open (srcfile , "w" ) as f :
337- f .write (jitmodule .code_to_compile )
338- self .comm .barrier ()
339- raise CompilationError ("Generated code differs across ranks (see output in %s)" % output )
350+ _check_src_hashes (self .comm , jitmodule )
351+
340352 try :
341353 # Are we in the cache?
342354 return ctypes .CDLL (soname )
@@ -652,3 +664,81 @@ def clear_cache(prompt=False):
652664 shutil .rmtree (cachedir , ignore_errors = True )
653665 else :
654666 print ("Not removing cached libraries" )
667+
668+
def _get_code_to_compile(comm, global_kernel):
    """Return the source code for ``global_kernel``, using a disk cache.

    The cache file is
    ``<cache_dir>/<first two hex digits>/<remaining digits>_code.cu`` where
    the digits come from the MD5 digest of ``str(global_kernel.cache_key[1:])``.

    This function is collective over ``comm``: rank 0 alone consults (and, on
    a miss, populates) the cache, and the resulting source is broadcast so
    that every rank takes the same path and returns the same string.

    :arg comm: communicator over which the source is shared.
    :arg global_kernel: object providing ``cache_key`` and
        ``code_to_compile``.
    :returns: the source code as a string.
    """
    # Determine cache key
    hsh = md5(str(global_kernel.cache_key[1:]).encode())
    digest = hsh.hexdigest()
    dirpart, basename = digest[:2], digest[2:]
    cachedir = os.path.join(configuration["cache_dir"], dirpart)
    cname = os.path.join(cachedir, f"{basename}_code.cu")

    _check_src_hashes(comm, global_kernel)

    code_to_compile = None
    if comm.rank == 0:
        # Only rank 0 looks at the file system.  Letting each rank test for
        # the file independently could send ranks down different branches.
        if os.path.isfile(cname):
            # Are we in the cache?
            with open(cname, "r") as f:
                code_to_compile = f.read()
        else:
            # No, let's go ahead and build
            os.makedirs(cachedir, exist_ok=True)
            with progress(INFO, "Compiling wrapper"):
                # make sure that compiles successfully before writing to file
                code_to_compile = global_kernel.code_to_compile
                # Write to a temporary name and rename into place so that a
                # concurrent reader never sees a partially written file.
                tmpname = "%s_p%d.tmp" % (cname, os.getpid())
                with open(tmpname, "w") as f:
                    f.write(code_to_compile)
                os.replace(tmpname, cname)

    # Hand rank 0's result to everyone; this also synchronises the ranks.
    return comm.bcast(code_to_compile, 0)
697+
698+
@mpi.collective
def get_prepared_cuda_function(comm, global_kernel):
    """Build ``global_kernel`` with PyCUDA and return its prepared function.

    The source is obtained from ``_get_code_to_compile`` and compiled by
    ``pycuda.compiler.SourceModule`` with the options ``-use_fast_math`` and
    ``-w``.  The directory passed as ``cache_dir`` is
    ``<cache_dir>/<first two hex digits of the cache-key MD5 digest>``.

    :arg comm: communicator forwarded to ``_get_code_to_compile``.
    :arg global_kernel: object providing ``cache_key``, ``name`` and
        ``argtypes``.
    :returns: the function looked up by ``global_kernel.name``, after
        ``prepare`` has been called on it.
    :raises KeyError: if ``global_kernel.argtypes`` contains a type other
        than ``ctypes.c_void_p`` or ``ctypes.c_int``.
    """
    from pycuda.compiler import SourceModule

    # Only the two leading hex digits of the digest are needed here; they
    # name the cache sub-directory.
    digest = md5(str(global_kernel.cache_key[1:]).encode()).hexdigest()
    module_cache = os.path.join(configuration["cache_dir"], digest[:2])

    source = _get_code_to_compile(comm, global_kernel)
    module = SourceModule(source,
                          options=["-use_fast_math", "-w"],
                          cache_dir=module_cache)
    prepared = module.get_function(global_kernel.name)

    # Map each ctypes argument type to the format character given to
    # ``prepare``.
    format_chars = {ctypes.c_void_p: "P", ctypes.c_int: "i"}
    signature = "".join([format_chars[argtype]
                         for argtype in global_kernel.argtypes])
    prepared.prepare(signature)
    return prepared
723+
724+
@mpi.collective
def get_opencl_kernel(comm, global_kernel):
    """Build ``global_kernel`` with PyOpenCL and return the kernel object.

    The source is obtained from ``_get_code_to_compile`` and built as a
    ``pyopencl.Program`` in the context taken from
    ``pyop2.backends.opencl.opencl_backend``.  The directory passed as
    ``cache_dir`` is
    ``<cache_dir>/<first two hex digits of the cache-key MD5 digest>``.

    :arg comm: communicator forwarded to ``_get_code_to_compile``.
    :arg global_kernel: object providing ``cache_key`` and ``name``.
    :returns: a ``pyopencl.Kernel`` created from the built program and
        ``global_kernel.name``.
    """
    import pyopencl as cl
    from pyop2.backends.opencl import opencl_backend

    context = opencl_backend.context

    # Only the two leading hex digits of the digest are needed here; they
    # name the cache sub-directory.
    digest = md5(str(global_kernel.cache_key[1:]).encode()).hexdigest()
    build_cache = os.path.join(configuration["cache_dir"], digest[:2])

    source = _get_code_to_compile(comm, global_kernel)

    program = cl.Program(context, source)
    built = program.build(options=[], cache_dir=build_cache)

    return cl.Kernel(built, global_kernel.name)
0 commit comments