Skip to content

Commit 78555da

Browse files
committed
[Executorch][custom ops] Change lib loading logic to account for package dir
Pull Request resolved: pytorch/executorch#7038 Just looking at the location of the source file. In this case custom_ops.py, can, and does, yield to wrong location depending on where you import custom_ops from. If you are importing custom_ops from another source file inside extension folder, e.g. builder.py that is in extensions/llm/export, then, I think, custom_ops gets resolved to the one installed in site-packages or pip package. But if this is imported from say examples/models/llama/source_transformations/quantized_kv_cache.py (Like in the in next PR), then it seems to resolve to the source location. In one of the CI this is /pytorch/executorch. Now depending on which directory your filepath resolves to, you will search for lib in that. This of course does not work when filepath resolves to source location. This PR changes that to resolve to package location. ghstack-source-id: 255065528 Differential Revision: [D66385480](https://our.internmc.facebook.com/intern/diff/D66385480/)
1 parent 68ca03c commit 78555da

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

extension/llm/custom_ops/custom_ops.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# pyre-unsafe
1212

1313
import logging
14-
from pathlib import Path
1514

1615
import torch
1716

@@ -23,7 +22,17 @@
2322
op2 = torch.ops.llama.fast_hadamard_transform.default
2423
assert op2 is not None
2524
except:
26-
libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
25+
import glob
26+
27+
import executorch
28+
29+
executorch_package_path = executorch.__path__[0]
30+
logging.info(f"Looking for libcustom_ops_aot_lib.so in {executorch_package_path }")
31+
libs = list(
32+
glob.glob(
33+
f"{executorch_package_path}/**/libcustom_ops_aot_lib.*", recursive=True
34+
)
35+
)
2736
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
2837
logging.info(f"Loading custom ops library: {libs[0]}")
2938
torch.ops.load_library(libs[0])

0 commit comments

Comments
 (0)