-
Notifications
You must be signed in to change notification settings - Fork 31
Zq/add specified autocompare #785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
60eb031
76d96f6
63a4ff7
538cf2f
5348043
6d8a8e3
4790315
ba73de3
e97db80
8e95f2f
1ba0e04
f27200b
e534ce2
f2d7ff9
5598c1b
3c63c9f
5cdba1a
9d9923e
458233c
389a267
c8dbb34
8d40cf9
d5a38d0
16585db
0a5cd89
17ce56f
5413aec
6931b62
dedb0c5
94b98f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里忘了删吧
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,16 +8,10 @@ | |
|
|
||
| #include <torch/library.h> | ||
|
|
||
| #include "csrc_dipu/aten/ops/OpRegexMatch.hpp" | ||
| #include "csrc_dipu/aten/ops/OpUtils.hpp" | ||
|
|
||
| namespace dipu { | ||
|
|
||
| bool get_force_fallback(const char* opname); | ||
|
|
||
| }; // namespace dipu | ||
|
|
||
| namespace at { | ||
|
|
||
| void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, | ||
| torch::jit::Stack* stack); | ||
|
|
||
|
|
@@ -52,11 +46,32 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, | |
| // It mat be necessary to determine whether to keep torchop default impl | ||
| // for non-custom ops through function dipuKeepTorchopDefaultImpl firstly in the | ||
| // future, and we use force fallback to keep torchop default impl now. | ||
| #define DIOPI_ATEN_FUNC(opname, diopiFunc, wapperFunc) \ | ||
| #define addAutoCompare(wrapperFunc) wrapperFunc##_autocompare | ||
lljbash marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| #define DIOPI_ATEN_FUNC(opname, diopiFunc, wrapperFunc) \ | ||
| do { \ | ||
|
||
| if ((reinterpret_cast<void*>(diopiFunc) != nullptr) && \ | ||
| (!dipu::whetherOpMatch(opname, fallbackMatchers))) { \ | ||
| if (dipu::whetherAutoCompare(opname, autocompareMatchers)) { \ | ||
| m.impl(opname, TORCH_FN(addAutoCompare(wrapperFunc))); \ | ||
| } else { \ | ||
| m.impl(opname, TORCH_FN(wrapperFunc)); \ | ||
| } \ | ||
| } else { \ | ||
| if ((reinterpret_cast<void*>(diopiFunc) == nullptr)) { \ | ||
| DIPU_OP_LOG_WARNING_ONCE(#diopiFunc << " is not yet implemented, "); \ | ||
| } else { \ | ||
| DIPU_OP_LOG_WARNING_ONCE("force fallback has been set, "); \ | ||
| } \ | ||
| DIPU_OP_LOG_WARNING_ONCE((opname) << " will be fallback to cpu" \ | ||
| << "\n"); \ | ||
lljbash marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } \ | ||
| } while (false); | ||
|
|
||
| #define DIOPI_ATEN_FUNC_DISABLE_AUTOCOMPARE(opname, diopiFunc, wrapperFunc) \ | ||
| do { \ | ||
|
||
| if ((reinterpret_cast<void*>(diopiFunc) != nullptr) && \ | ||
| (!dipu::get_force_fallback(opname))) { \ | ||
| m.impl(opname, TORCH_FN(wapperFunc)); \ | ||
| (!dipu::whetherOpMatch(opname, fallbackMatchers))) { \ | ||
| m.impl(opname, TORCH_FN(wrapperFunc)); \ | ||
| } else { \ | ||
| if ((reinterpret_cast<void*>(diopiFunc) == nullptr)) { \ | ||
| DIPU_OP_LOG_WARNING_ONCE(#diopiFunc << " is not yet implemented, "); \ | ||
|
|
@@ -71,14 +86,15 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, | |
| // Determine whether to keep torchop default impl for custom ops through | ||
| // function dipuKeepTorchopDefaultImpl firstly. | ||
| #define DIOPI_ATEN_FUNC_CUSTOM_FALLBACK(opname, diopi_func, force_fallback, \ | ||
| wapper_func, custom_fallback_func) \ | ||
| wrapper_func, custom_fallback_func) \ | ||
| do { \ | ||
| if (dipu::native::dipuKeepTorchopDefaultImpl(opname)) { \ | ||
| break; \ | ||
| } \ | ||
| if ((reinterpret_cast<void*>(diopi_func) != nullptr) && \ | ||
| !((force_fallback) || dipu::get_force_fallback(opname))) { \ | ||
| m.impl(opname, TORCH_FN(wapper_func)); \ | ||
| !((force_fallback) || \ | ||
| dipu::whetherOpMatch(opname, fallbackMatchers))) { \ | ||
| m.impl(opname, TORCH_FN(wrapper_func)); \ | ||
| } else { \ | ||
| if ((reinterpret_cast<void*>(diopi_func) == nullptr)) { \ | ||
| DIPU_OP_LOG_WARNING_ONCE(#diopi_func << " is not yet implemented, "); \ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.