diff --git a/backends/npu/kernels/conv2d_kernel.cc b/backends/npu/kernels/conv2d_kernel.cc
index c7fac1eb153..60b5f7eb177 100644
--- a/backends/npu/kernels/conv2d_kernel.cc
+++ b/backends/npu/kernels/conv2d_kernel.cc
@@ -96,6 +96,13 @@ void Conv2dKernel(const Context& dev_ctx,
                   int groups,
                   const std::string& data_format,
                   phi::DenseTensor* output) {
+  // The once_flag must be static: a fresh automatic flag on every call would
+  // make call_once run UpdateBoolFlag each time the kernel is invoked.
+  static std::once_flag npu_jit_compile_flag;
+  std::call_once(npu_jit_compile_flag,
+                 UpdateBoolFlag,
+                 "FLAGS_npu_jit_compile",
+                 &FLAGS_npu_jit_compile);
   if (FLAGS_npu_jit_compile) {
     aclSetCompileopt(ACL_OP_JIT_COMPILE, "disable");
   }
@@ -328,6 +335,13 @@ void Conv2DGradKernel(const Context& dev_ctx,
                       const std::string& data_format,
                       phi::DenseTensor* input_grad,
                       phi::DenseTensor* filter_grad) {
+  // The once_flag must be static: a fresh automatic flag on every call would
+  // make call_once run UpdateBoolFlag each time the kernel is invoked.
+  static std::once_flag npu_jit_compile_flag;
+  std::call_once(npu_jit_compile_flag,
+                 UpdateBoolFlag,
+                 "FLAGS_npu_jit_compile",
+                 &FLAGS_npu_jit_compile);
   if (FLAGS_npu_jit_compile) {
     aclSetCompileopt(ACL_OP_JIT_COMPILE, "disable");
   }
diff --git a/backends/npu/kernels/funcs/npu_op_runner.cc b/backends/npu/kernels/funcs/npu_op_runner.cc
index 8222c56b935..4f194167dca 100644
--- a/backends/npu/kernels/funcs/npu_op_runner.cc
+++ b/backends/npu/kernels/funcs/npu_op_runner.cc
@@ -620,6 +620,7 @@ void NpuOpRunner::Run(aclrtStream stream, bool sync) const {
 
   static std::once_flag jit_compile_flag;
   std::call_once(jit_compile_flag, [&]() {
+    UpdateBoolFlag("FLAGS_npu_jit_compile", &FLAGS_npu_jit_compile);
     if (FLAGS_npu_jit_compile) {
       aclSetCompileopt(ACL_OP_JIT_COMPILE, "enable");
     } else {
diff --git a/backends/npu/kernels/pool2d_kernel.cc b/backends/npu/kernels/pool2d_kernel.cc
index 112b9bafb2d..a1cb1b06ec1 100644
--- a/backends/npu/kernels/pool2d_kernel.cc
+++ b/backends/npu/kernels/pool2d_kernel.cc
@@ -444,6 +444,13 @@ void AclopPool2dGradKernel(const Context& dev_ctx,
     cast_out_tensor = out_tensor;
   }
 
+  // The once_flag must be static: a fresh automatic flag on every call would
+  // make call_once run UpdateBoolFlag each time the kernel is invoked.
+  static std::once_flag npu_jit_compile_flag;
+  std::call_once(npu_jit_compile_flag,
+                 UpdateBoolFlag,
+                 "FLAGS_npu_jit_compile",
+                 &FLAGS_npu_jit_compile);
   if (!FLAGS_npu_jit_compile) {
     aclSetCompileopt(ACL_OP_JIT_COMPILE, "enable");
   }
diff --git a/backends/npu/runtime/flags.h b/backends/npu/runtime/flags.h
index 271ccbd032a..7c67c844999 100644
--- a/backends/npu/runtime/flags.h
+++ b/backends/npu/runtime/flags.h
@@ -49,6 +49,9 @@
 #ifndef BACKENDS_NPU_RUNTIME_FLAGS_H_
 #define BACKENDS_NPU_RUNTIME_FLAGS_H_
 
+#include <mutex>
+#include <string>
+
 #include "gflags/gflags.h"
 
 #define FLAGS_DEFINE_bool(name, value, meaning) \
@@ -77,4 +80,9 @@
 #define EnvToUInt(envname, dflt) \
   (!getenv(envname) ? (dflt) : strtoul(getenv(envname), NULL, 10))
 
+// Overwrites *flag with EnvToBool(envname, *flag); the current value of *flag
+// is passed as EnvToBool's second argument. `flag` must not be null.
+inline void UpdateBoolFlag(const std::string& envname, bool* flag) {
+  *flag = EnvToBool(envname.c_str(), *flag);
+}
 #endif  // BACKENDS_NPU_RUNTIME_FLAGS_H_