Skip to content

Commit d892a83

Browse files
authored
Merge pull request #654 from Jittor/hyx
merge hw backend
2 parents b79ac22 + 4017b16 commit d892a83

File tree

127 files changed

+11133
-2421
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

127 files changed

+11133
-2421
lines changed

python/jittor/compile_extern.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,26 @@ def setup_nccl():
611611
nccl_ops = nccl.ops
612612
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
613613

614+
def setup_hccl():
615+
global hccl_ops
616+
617+
hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl")
618+
hccl_src_files = []
619+
for r, _, f in os.walk(hccl_src_dir):
620+
for fname in f:
621+
hccl_src_files.append(os.path.join(r, fname))
622+
623+
hccl_include_path = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/include/hccl")
624+
hccl_lib_name = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/lib64/libhccl.so")
625+
ctypes.CDLL(hccl_lib_name, dlopen_flags)
626+
627+
hccl = compile_custom_ops(hccl_src_files,
628+
extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ",
629+
return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
630+
gen_name_="jittor_hccl_core")
631+
hccl_ops = hccl.ops
632+
LOG.vv("Get hccl_ops: "+str(dir(hccl_ops)))
633+
614634
def manual_link(flags):
615635
lib_dirs = []
616636
libs = []
@@ -708,8 +728,14 @@ def inner(self, *args, **kw):
708728
setup_mpi()
709729
rank = mpi.world_rank() if in_mpi else 0
710730
world_size = mpi.world_size() if in_mpi else 1
731+
# if has_acl:
732+
# setup_hccl()
733+
# elif has_cuda:
734+
# setup_nccl()
735+
# setup_cutt()
736+
# setup_cutlass()
737+
711738
setup_nccl()
712-
713739
setup_cutt()
714740
setup_cutlass()
715741

python/jittor/compiler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,22 @@ def fix_cl_flags(cmd):
11881188

11891189
# build cache_compile
11901190
cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
1191+
cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
1192+
1193+
ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
1194+
1195+
if ascend_toolkit_home:
1196+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
1197+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
1198+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
1199+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
1200+
cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
1201+
cc_flags += " -llibascendcl "
1202+
cc_flags += " -llibnnopbase "
1203+
cc_flags += " -llibopapi "
1204+
11911205
cc_flags += py_include
1206+
11921207
check_cache_compile()
11931208
LOG.v(f"Get cache_compile: {jit_utils.cc}")
11941209

0 commit comments

Comments
 (0)