@@ -125,11 +125,33 @@ def run(self):
125125 super ().run ()
126126
127127
128- setup (name = "triton-kernels-benchmark" , packages = [
129- "triton_kernels_benchmark" ,
130- ], package_dir = {
131- "triton_kernels_benchmark" : "triton_kernels_benchmark" ,
132- }, package_data = {"triton_kernels_benchmark" : ["xetla_kernel.cpython-*.so" ]}, cmdclass = {
133- "build_ext" : build_ext ,
134- "clean" : clean ,
135- }, ext_modules = [CMakeExtension ("triton_kernels_benchmark" )])
128+ def get_git_commit_hash (length = 8 ):
129+ try :
130+ cmd = ["git" , "rev-parse" , f"--short={ length } " , "HEAD" ]
131+ return "+git{}" .format (subprocess .check_output (cmd ).strip ().decode ("utf-8" ))
132+ except Exception :
133+ return ""
134+
135+
136+ def get_install_requires ():
137+ install_requires = ["torch" , "matplotlib" , "pandas" , "tabulate" ] # yapf: disable
138+ return install_requires
139+
140+
141+ setup (
142+ name = "triton-kernels-benchmark" ,
143+ version = "3.1.0" + get_git_commit_hash (),
144+ packages = ["triton_kernels_benchmark" ],
145+ install_requires = get_install_requires (),
146+ package_dir = {"triton_kernels_benchmark" : "triton_kernels_benchmark" },
147+ package_data = {"triton_kernels_benchmark" : ["xetla_kernel.cpython-*.so" ]},
148+ cmdclass = {
149+ "build_ext" : build_ext ,
150+ "clean" : clean ,
151+ },
152+ ext_modules = [CMakeExtension ("triton_kernels_benchmark" )],
153+ extra_require = {
154+ "ipex" : ["numpy<=2.0" , "intel-extension-for-pytorch=2.1.10" ],
155+ "pytorch" : ["torch>=2.6" ]
156+ },
157+ )
0 commit comments