Skip to content

Commit d3fdf17

Browse files
authored
[xpu] fix static kernel pass (#9672)
1 parent 954b732 commit d3fdf17

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ bool XPUKernelScoreCmp(const std::pair<float, std::unique_ptr<KernelBase>>& a,
3333
}
3434

3535
void XPUStaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
36+
Init();
37+
3638
kernel_pick_factors_.ConsiderTarget();
3739
kernel_pick_factors_.ConsiderPrecision();
3840
kernel_pick_factors_.ConsiderDataLayout();

lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,17 @@ namespace mir {
4242
*/
4343
class XPUStaticKernelPickPass : public mir::StmtPass {
4444
public:
45-
XPUStaticKernelPickPass() {
45+
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
46+
47+
const core::KernelPickFactor& kernel_pick_factors() const {
48+
return kernel_pick_factors_;
49+
}
50+
core::KernelPickFactor* mutable_kernel_pick_factors() {
51+
return &kernel_pick_factors_;
52+
}
53+
54+
private:
55+
void Init() {
4656
#ifdef LITE_WITH_XPU
4757
// get xpu device type
4858
int cur_dev_idx = 0;
@@ -72,16 +82,6 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
7282
#endif
7383
}
7484

75-
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
76-
77-
const core::KernelPickFactor& kernel_pick_factors() const {
78-
return kernel_pick_factors_;
79-
}
80-
core::KernelPickFactor* mutable_kernel_pick_factors() {
81-
return &kernel_pick_factors_;
82-
}
83-
84-
private:
8585
// Score the kernel.
8686
size_t KernelGrade(lite::mir::Node* node,
8787
const lite::KernelBase& kernel,

0 commit comments

Comments
 (0)