Skip to content

Commit c78db2a

Browse files
committed
enable cuda and acl
1 parent f8e44de commit c78db2a

File tree

7 files changed

+27
-24
lines changed

7 files changed

+27
-24
lines changed

python/jittor/compile_extern.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,7 @@ def setup_cutt():
457457
def install_cutlass(root_folder):
458458
# Modified from: https://github.com/ap-hynninen/cutlass
459459
# url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"
460-
# url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/cutlass.zip"
461-
url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1"
460+
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/cutlass.zip"
462461

463462
filename = "cutlass.zip"
464463
fullname = os.path.join(root_folder, filename)

python/jittor/compiler.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,20 +1186,22 @@ def fix_cl_flags(cmd):
11861186
ck_path = os.path.join(cache_path, "checkpoints")
11871187
make_cache_dir(ck_path)
11881188

1189-
1190-
ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
1191-
11921189
# build cache_compile
11931190
cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
11941191
cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
1195-
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
1196-
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
1197-
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
1198-
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
1199-
cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
1200-
cc_flags += " -llibascendcl "
1201-
cc_flags += " -llibnnopbase "
1202-
cc_flags += " -llibopapi "
1192+
1193+
ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
1194+
1195+
if ascend_toolkit_home:
1196+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
1197+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
1198+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
1199+
cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
1200+
cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
1201+
cc_flags += " -llibascendcl "
1202+
cc_flags += " -llibnnopbase "
1203+
cc_flags += " -llibopapi "
1204+
12031205
cc_flags += py_include
12041206

12051207
check_cache_compile()

python/jittor/extern/acl/acl_jittor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// ***************************************************************
77
#pragma once
88
#include "common.h"
9+
#include "aclnn/aclnn.h"
910
#include <acl/acl.h>
1011

1112
std::string acl_error_to_string(aclError error);

python/jittor/extern/acl/aclops/utils.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <Python.h>
66
#include <pystate.h>
77
#include "utils.h"
8+
#include "aclnn/aclnn.h"
89

910
namespace jittor
1011
{

python/jittor/extern/acl/aclops/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <Python.h>
77
#include <pystate.h>
88
#include "misc/nano_string.h"
9+
#include "aclnn/aclnn.h"
910

1011
namespace jittor
1112
{

python/jittor/src/common.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <memory>
99
#include <functional>
1010
#include "utils/log.h"
11-
#include "../extern/acl/aclnn/aclnn.h"
1211

1312
#define JIT_TEST(name) extern void jit_test_ ## name ()
1413
void expect_error(std::function<void()> func);

python/jittor/src/ops/fetch_op.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Init() {
4747
if (!get_device_count()) return;
4848
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
4949
checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
50-
stream = aclstream;
50+
// stream = aclstream;
5151
}
5252
~Init() {
5353
if (!get_device_count()) return;
@@ -123,23 +123,23 @@ void FetchOp::run() {
123123
new (&allocation) Allocation(&cuda_dual_allocator, v->size);
124124
// mostly device to device
125125
#if IS_CUDA
126-
// checkCudaErrors(cudaMemcpyAsync(
127-
// allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
128126
checkCudaErrors(cudaMemcpyAsync(
129-
allocation.ptr, v->size, v->mem_ptr, v->size, cudaMemcpyDefault, aclstream));
130-
checkCudaErrors(aclrtSynchronizeStream(aclstream));
127+
allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
128+
// checkCudaErrors(cudaMemcpyAsync(
129+
// allocation.ptr, v->size, v->mem_ptr, v->size, cudaMemcpyDefault, aclstream));
130+
// checkCudaErrors(aclrtSynchronizeStream(aclstream));
131131
#else
132132
checkCudaErrors(cudaMemcpyAsync(
133133
allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDeviceToDevice, stream));
134134
#endif
135135
auto host_ptr = cuda_dual_allocator.get_dual_allocation(
136136
allocation.allocation).host_ptr;
137137
// device to host
138-
// checkCudaErrors(cudaMemcpyAsync(
139-
// host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream));
140-
checkCudaErrors(aclrtMemcpyAsync(
141-
host_ptr, v->size, allocation.ptr, v->size, cudaMemcpyDeviceToHost, aclstream));
142-
checkCudaErrors(aclrtSynchronizeStream(aclstream));
138+
checkCudaErrors(cudaMemcpyAsync(
139+
host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream));
140+
// checkCudaErrors(aclrtMemcpyAsync(
141+
// host_ptr, v->size, allocation.ptr, v->size, cudaMemcpyDeviceToHost, aclstream));
142+
// checkCudaErrors(aclrtSynchronizeStream(aclstream));
143143
allocation.ptr = host_ptr;
144144
has_cuda_memcpy = true;
145145
} else

0 commit comments

Comments
 (0)