Skip to content

Commit 92821f6

Browse files
committed
CANN: use env instead
1 parent c54a207 commit 92821f6

File tree

3 files changed

+179
-211
lines changed

3 files changed

+179
-211
lines changed

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 161 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,126 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
651651
*/
652652
void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
653653

654+
/**
 * @brief Type-erased owning handle for an ACL object.
 *
 * The pointee type is erased to void; the deleter stored alongside the pointer
 * is responsible for destroying the object with the matching routine.
 */
using AnyAclResource = std::unique_ptr<void, std::function<void(void *)>>;

/**
 * @brief Maps an ACL object type to the routine that destroys it.
 *
 * Only declared here; each supported type supplies a specialization that
 * provides a static `destroy(void *)` member.
 */
template <typename T> struct AclResourceTraits;
658+
template<>
659+
struct AclResourceTraits<aclTensor> {
660+
static void destroy(void* p) {
661+
ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
662+
}
663+
};
664+
template<>
665+
struct AclResourceTraits<aclIntArray> {
666+
static void destroy(void* p) {
667+
ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
668+
}
669+
};
670+
template<>
671+
struct AclResourceTraits<aclScalar> {
672+
static void destroy(void* p) {
673+
ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
674+
}
675+
};
676+
template<>
677+
struct AclResourceTraits<aclTensorList> {
678+
static void destroy(void* p) {
679+
ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
680+
}
681+
};
682+
683+
template<typename T>
684+
AnyAclResource make_acl_resource(T* ptr) {
685+
return AnyAclResource(
686+
static_cast<void*>(ptr),
687+
[](void* p) {
688+
AclResourceTraits<T>::destroy(p);
689+
}
690+
);
691+
}
692+
693+
template<typename... Args>
694+
void register_acl_resources(std::vector<AnyAclResource>& vec, Args*... args) {
695+
(vec.emplace_back(make_acl_resource(args)), ...);
696+
}
697+
698+
class aclnn_task : public cann_task {
699+
public:
700+
aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr, uint64_t workspace_size, aclOpExecutor * executor,
701+
aclrtStream stream) :
702+
aclnn_func_(aclnn_func),
703+
workspace_addr_(workspace_addr),
704+
workspace_size_(workspace_size),
705+
executor_(executor),
706+
stream_(stream) {}
707+
virtual void run_task() override {
708+
ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
709+
}
710+
private:
711+
aclnn_func_t aclnn_func_;
712+
void * workspace_addr_;
713+
uint64_t workspace_size_;
714+
aclOpExecutor * executor_;
715+
aclrtStream stream_;
716+
};
717+
718+
class resource_task : public cann_task {
719+
public:
720+
resource_task(std::vector<AnyAclResource>&& resources){
721+
resource_ = std::move(resources);
722+
}
723+
724+
virtual void run_task() override {
725+
resource_.clear();
726+
}
727+
private:
728+
std::vector<AnyAclResource> resource_;
729+
};
730+
731+
class free_ptr_task : public cann_task {
732+
public:
733+
free_ptr_task(void* ptr) : ptr_(ptr) {}
734+
735+
virtual void run_task() override {
736+
free(ptr_);
737+
}
738+
private:
739+
void* ptr_;
740+
};
741+
742+
class async_memcpy_task : public cann_task {
743+
public:
744+
async_memcpy_task(void* dst, const void* src, size_t size, aclrtMemcpyKind kind, aclrtStream stream)
745+
: dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
746+
747+
virtual void run_task() override {
748+
749+
ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
750+
}
751+
private:
752+
void* dst_;
753+
const void* src_;
754+
size_t size_;
755+
aclrtMemcpyKind kind_;
756+
aclrtStream stream_;
757+
};
758+
759+
class async_memset_task : public cann_task {
760+
public:
761+
async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
762+
: buffer_(buffer), size_(size), value_(value), stream_(stream) {}
763+
764+
virtual void run_task() override {
765+
ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
766+
}
767+
private:
768+
void* buffer_;
769+
size_t size_;
770+
int32_t value_;
771+
aclrtStream stream_;
772+
};
654773

655-
//#define ASYNC_SUBMIT
656774
/**
657775
* @brief Launches an asynchronous task using the memory allocator.
658776
*
@@ -670,94 +788,67 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
670788
* other task before this asynchronous task ends, because all tasks in the
671789
* same stream are executed in queue order.
672790
*/
673-
#ifdef ASYNC_SUBMIT
674-
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
675-
do { \
676-
uint64_t workspaceSize = 0; \
677-
aclOpExecutor * executor; \
678-
void * workspaceAddr = nullptr; \
679-
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
680-
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
681-
if (workspaceSize > 0) { \
682-
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
683-
workspaceAddr = workspace_allocator.get(); \
684-
} \
685-
auto task = std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, \
686-
workspaceSize, executor, CTX.stream()); \
687-
CTX.task_queue.submit_task(std::move(task)); \
791+
792+
/*
 * Runs the aclnn operator OP_NAME in two stages:
 *   1. aclnn<OP_NAME>GetWorkspaceSize(__VA_ARGS__, ...) yields the workspace
 *      size and the executor;
 *   2. aclnn<OP_NAME> is either wrapped in an aclnn_task and submitted to
 *      CTX.task_queue (when CTX.async_mode is set) or called directly on
 *      CTX.stream().
 * A workspace is taken from CTX.pool() only when the reported size is non-zero.
 */
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                          \
    do {                                                                                    \
        uint64_t        workspaceSize = 0;                                                  \
        aclOpExecutor * executor;                                                           \
        void *          workspaceAddr = nullptr;                                            \
        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
        /* The workspace must be allocated in the main thread to keep the malloc */         \
        /* order when using vmm. */                                                         \
        if (workspaceSize > 0) {                                                            \
            /* The allocator object is local to this block; only the address is kept. */    \
            ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);            \
            workspaceAddr = workspace_allocator.get();                                      \
        }                                                                                   \
        if (CTX.async_mode) {                                                               \
            auto queued_op = std::make_unique<aclnn_task>(                                  \
                aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream());      \
            CTX.task_queue.submit_task(std::move(queued_op));                               \
        } else {                                                                            \
            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
        }                                                                                   \
    } while (0)
689812

690813
template <typename... Args>
691814
void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
692815
std::vector<AnyAclResource> resources;
693816
register_acl_resources(resources, std::forward<Args>(args)...);
694-
auto task = std::make_unique<resource_task>(std::move(resources));
695-
ctx.task_queue.submit_task(std::move(task));
696-
}
697-
698-
inline void ggml_cann_async_free(ggml_backend_cann_context * ctx, void * ptr) {
699-
auto task = std::make_unique<free_ptr_task>(ptr);
700-
ctx->task_queue.submit_task(std::move(task));
817+
if(ctx.async_mode) {
818+
auto task = std::make_unique<resource_task>(std::move(resources));
819+
ctx.task_queue.submit_task(std::move(task));
820+
}
701821
}
702822

703823
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
704824
const void * src, size_t len, aclrtMemcpyKind kind) {
705-
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
706-
ctx.task_queue.submit_task(std::move(task));
707-
}
708-
709-
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
710-
const void * src, size_t len, aclrtMemcpyKind kind) {
711-
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
712-
ctx->task_queue.submit_task(std::move(task));
713-
}
714-
715-
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
716-
size_t size, int value) {
717-
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
718-
ctx.task_queue.submit_task(std::move(task));
719-
}
720-
#else
721-
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
722-
do { \
723-
uint64_t workspaceSize = 0; \
724-
aclOpExecutor * executor; \
725-
void * workspaceAddr = nullptr; \
726-
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
727-
if (workspaceSize > 0) { \
728-
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
729-
workspaceAddr = workspace_allocator.get(); \
730-
} \
731-
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
732-
} while (0)
733-
734-
template <typename... Args>
735-
void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
736-
GGML_UNUSED(ctx);
737-
std::vector<AnyAclResource> resources;
738-
register_acl_resources(resources, std::forward<Args>(args)...);
739-
}
740-
741-
inline void ggml_cann_async_free(ggml_backend_cann_context * ctx, void * ptr) {
742-
ACL_CHECK(aclrtSynchronizeStream(ctx->stream()));
743-
free(ptr);
744-
}
745-
746-
/**
 * @brief Copies @p len bytes from @p src to @p dst on the context's stream.
 *
 * When ctx.async_mode is set, the copy is wrapped in an async_memcpy_task and
 * submitted to ctx.task_queue; otherwise aclrtMemcpyAsync is called directly.
 *
 * @param ctx  Backend context supplying async_mode, task_queue and stream().
 * @param dst  Destination address.
 * @param src  Source address.
 * @param len  Byte count, used as both destination capacity and copy length.
 * @param kind Copy kind forwarded to aclrtMemcpyAsync.
 */
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
                                   const void * src, size_t len, aclrtMemcpyKind kind) {
    if (!ctx.async_mode) {
        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
        return;
    }

    // async_memcpy_task takes a const source pointer, so no cast is needed.
    auto copy_op = std::make_unique<async_memcpy_task>(dst, src, len, kind, ctx.stream());
    ctx.task_queue.submit_task(std::move(copy_op));
}
750832

751833
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
752834
const void * src, size_t len, aclrtMemcpyKind kind) {
753-
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
835+
if (ctx->async_mode) {
836+
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
837+
ctx->task_queue.submit_task(std::move(task));
838+
} else {
839+
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
840+
}
754841
}
755842

756843
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
757844
size_t size, int value) {
758-
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
845+
if (ctx.async_mode) {
846+
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
847+
ctx.task_queue.submit_task(std::move(task));
848+
} else {
849+
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
850+
}
759851
}
760-
#endif
761852

762853
/**
763854
* @brief Applies a element-wise operation to two input tensors using the CANN

0 commit comments

Comments
 (0)