Skip to content

Commit e6b38bb

Browse files
authored
Fix setup develop (#2748)
* Fix setup.py develop workflow * up * up
1 parent 758f744 commit e6b38bb

File tree

4 files changed

+29
-7
lines changed

4 files changed

+29
-7
lines changed

setup.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,21 @@ def build_cmake(self, ext):
317317
if not os.path.exists(self.build_temp):
318318
os.makedirs(self.build_temp)
319319

320+
# Get the expected extension file name that Python will look for
321+
# We force CMake to use this library name
322+
ext_filename = os.path.basename(self.get_ext_filename(ext.name))
323+
ext_basename = os.path.splitext(ext_filename)[0]
324+
320325
subprocess.check_call(
321326
[
322327
"cmake",
323328
ext.cmake_lists_dir,
324329
]
325330
+ ext.cmake_args
326-
+ ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir],
331+
+ [
332+
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
333+
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
334+
],
327335
cwd=self.build_temp,
328336
)
329337
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
@@ -708,7 +716,7 @@ def bool_to_on_off(value):
708716

709717
ext_modules.append(
710718
CMakeExtension(
711-
"torchao.experimental",
719+
"torchao._experimental_aten_ops",
712720
cmake_lists_dir="torchao/experimental",
713721
cmake_args=(
714722
[

torchao/experimental/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,18 @@ if(TORCHAO_BUILD_ATEN_OPS)
136136
ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp
137137
)
138138
list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
139+
140+
# Use the Python extension name if provided
139141
add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten})
142+
if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME)
143+
message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so")
144+
set_target_properties(torchao_ops_aten PROPERTIES
145+
OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME}
146+
PREFIX "" # Remove "lib" prefix for Python extensions
147+
SUFFIX ".so" # Add ".so" suffix for Python extensions
148+
)
149+
endif()
150+
140151
target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}")
141152
if (TORCHAO_BUILD_CPU_AARCH64)
142153
target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64)

torchao/experimental/op_lib.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@
2222

2323

2424
def find_and_load_libtorchao_ops(potential_paths):
25+
"""
26+
Finds and loads torchao._experimental_aten_ops from one of the provided paths
27+
"""
28+
2529
for lib_path in potential_paths:
26-
libs = list(lib_path.glob("libtorchao_ops_aten.*"))
30+
libs = list(lib_path.glob("_experimental_aten_ops.*"))
2731

2832
if not libs:
2933
continue
3034

3135
assert len(libs) == 1, (
32-
f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
36+
f"Expected to find one _experimental_aten_ops.* library at {lib_path}, but found {len(libs)}"
3337
)
3438

3539
target_lib = libs[0]

torchao/experimental/tests/test_embedding_xbit_quantizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,9 @@ def test_shared_embedding(self):
183183
self.assertTrue(torch.allclose(result, exported_result))
184184

185185
# Check the shared_embedding and linear ops use the same lifted weight
186-
weight = "b_getattr_l__fn_____0___unembedding_packed_weights"
187186
expected_lines = [
188-
f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)",
189-
f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)",
187+
"torch.ops.torchao._shared_embedding_4bit.default",
188+
"torch.ops.torchao._linear_8bit_act_4bit_weight.default",
190189
]
191190
for line in expected_lines:
192191
FileCheck().check_count(line, 1, exactly=True).run(

0 commit comments

Comments
 (0)