33import os
44import shutil
55import sys
6- import warnings
76from typing import List
87
9- import pybind11 # noqa: F401
10- import setuptools
8+ import setuptools # noqa: F401
119import torch
1210from setuptools .command .build_ext import build_ext
1311from torch .utils .cpp_extension import _TORCH_PATH
@@ -336,14 +334,8 @@ def _prepare_ldflags(extra_ldflags, verbose, is_standalone):
336334 extra_ldflags .append (f"-Wl,-rpath,{ TORCH_LIB_PATH } " )
337335
338336 library_dirs = library_paths ()
339- # Append oneMKL link parameters, detailed please reference:
340- # https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl-link-line-advisor.html
341337 oneapi_link_args += [f"-L{ x } " for x in library_dirs ]
342- # oneapi_link_args += ['-fsycl-device-code-split=per_kernel']
343- oneapi_link_args += ["-Wl,--start-group" ]
344- oneapi_link_args += [f"{ x } " for x in get_one_api_help ().get_onemkl_libraries ()]
345- oneapi_link_args += ["-Wl,--end-group" ]
346- oneapi_link_args += ["-ldnnl" , "-lOpenCL" , "-lpthread" , "-lm" , "-ldl" ]
338+ oneapi_link_args += ["-Wl" , "-ldnnl" , "-lOpenCL" , "-lpthread" , "-lm" , "-ldl" ]
347339 oneapi_link_args += ['-lintel-ext-pt-gpu' ]
348340
349341 extra_ldflags += oneapi_link_args
@@ -356,61 +348,20 @@ def _get_dpcpp_root():
356348 return dpcpp_root
357349
358350
359- def _get_onemkl_root ():
360- # TODO: Need to decouple with toolchain env scripts
361- path = os .getenv ("MKLROOT" )
362- return path
363-
364-
365- def _get_onednn_root ():
366- # TODO: Need to decouple with toolchain env scripts
367- path = os .getenv ("DNNLROOT" )
368- return path
369-
370-
371351class _one_api_help :
372352 __dpcpp_root = None
373- __onemkl_root = None
374- __onednn_root = None
375353 __ipex_root = None
376354
377355 def __init__ (self ):
356+ import intel_extension_for_pytorch
378357 self .__dpcpp_root = _get_dpcpp_root ()
379- self .__onemkl_root = _get_onemkl_root ()
380- self .__onednn_root = _get_onednn_root ()
381-
382- infos = os .popen ("pip show intel_extension_for_pytorch" ).read ().split ("\n " )
383- for info in infos :
384- if "Location" in info :
385- ipex_path = info [10 :]
386- ipex_path = os .path .join (ipex_path , "intel_extension_for_pytorch" )
387-
388- self .__ipex_root = ipex_path
389-
390- self .check_onednn_cfg ()
358+ self .__ipex_root = os .path .dirname (intel_extension_for_pytorch .__file__ )
391359 self .check_dpcpp_cfg ()
392- self .check_onemkl_cfg ()
393-
394- def check_onemkl_cfg (self ):
395- if self .__onemkl_root is None :
396- raise "Didn't detect mkl root. Please source <oneapi_dir>/mkl/<version>/env/vars.sh "
397-
398- def check_onednn_cfg (self ):
399- if self .__onednn_root is None :
400- raise "Didn't detect dnnl root. Please source <oneapi_dir>/dnnl/<version>/env/vars.sh "
401- else :
402- warnings .warn (
403- "This extension has static linked onednn library. Please attaction to \
404- that, this path of onednn version maybe not match with the built-in version."
405- )
406360
407361 def check_dpcpp_cfg (self ):
408362 if self .__dpcpp_root is None :
409363 raise "Didn't detect dpcpp root. Please source <oneapi_dir>/compiler/<version>/env/vars.sh "
410364
411- def get_ipex_include_dir (self ):
412- return [os .path .join (self .__ipex_root , "include" )]
413-
414365 def get_ipex_lib_dir (self ):
415366 return [os .path .join (self .__ipex_root , "lib" )]
416367
@@ -420,59 +371,20 @@ def get_dpcpp_include_dir(self):
420371 os .path .join (self .__dpcpp_root , "linux" , "include" , "sycl" ),
421372 ]
422373
423- def get_onemkl_include_dir (self ):
424- return [os .path .join (self .__onemkl_root , "include" )]
425-
426- def get_onednn_include_dir (self ):
427- return [os .path .join (self .__onednn_root , "include" )]
428-
429- def get_onednn_lib_dir (self ):
430- return [os .path .join (self .__onednn_root , "lib" )]
431-
432- def is_onemkl_ready (self ):
433- if self .__onemkl_root is None :
434- return False
435- return True
436-
437- def is_onednn_ready (self ):
438- if self .__onednn_root is None :
439- return False
440- return True
441-
442374 def get_library_dirs (self ):
443375 library_dirs = []
444376 library_dirs += [f"{ x } " for x in self .get_ipex_lib_dir ()]
445- library_dirs += [f"{ x } " for x in self .get_onednn_lib_dir ()]
446377 return library_dirs
447378
448379 def get_include_dirs (self ):
449380 include_dirs = []
450381 include_dirs += [f"{ x } " for x in self .get_dpcpp_include_dir ()]
451- include_dirs += [f"{ x } " for x in self .get_onemkl_include_dir ()]
452- include_dirs += [f"{ x } " for x in self .get_onednn_include_dir ()]
453- include_dirs += [f"{ x } " for x in self .get_ipex_include_dir ()]
454382 return include_dirs
455383
456- def get_onemkl_libraries (self ):
457- MKLROOT = self .__onemkl_root
458- return [
459- f"{ MKLROOT } /lib/intel64/libmkl_sycl.a" ,
460- f"{ MKLROOT } /lib/intel64/libmkl_intel_ilp64.a" ,
461- f"{ MKLROOT } /lib/intel64/libmkl_sequential.a" ,
462- f"{ MKLROOT } /lib/intel64/libmkl_core.a" ,
463- ]
464-
465384
466- def get_pytorch_include_dir ():
467- lib_include = os .path .join (_TORCH_PATH , "include" )
468- paths = [
469- lib_include ,
470- # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
471- os .path .join (lib_include , "torch" , "csrc" , "api" , "include" ),
472- # Some internal (old) Torch headers don't properly prefix their includes,
473- # so we need to pass -Itorch/lib/include/TH as well.
474- os .path .join (lib_include , "TH" ),
475- ]
385+ def get_pytorch_ipex_onemkl_include_dir ():
386+ import intel_extension_for_pytorch
387+ paths = intel_extension_for_pytorch .xpu .cpp_extension .include_paths ()
476388 return paths
477389
478390
@@ -491,21 +403,13 @@ def include_paths() -> List[str]:
491403 if use_profile ():
492404 # add pytorch include directories
493405 paths = []
494- paths += get_pytorch_include_dir ()
406+ paths += get_pytorch_ipex_onemkl_include_dir ()
495407
496408 # add oneAPI include directories
497409 paths += get_one_api_help ().get_include_dirs ()
498410 else :
499- infos = os .popen ("pip show pybind11" ).read ().split ("\n " )
500- for info in infos :
501- if "Location" in info :
502- pybind11_path = info [10 :]
503-
504- pybind11_path = os .path .join (pybind11_path , 'pybind11/include' )
505- if not os .path .exists (pybind11_path ):
506- raise Exception ("Didn't found pybind11 in conda site-packages, pls try pip install pybind11" )
507-
508- paths = [pybind11_path , ]
411+ import torch .utils .cpp_extension
412+ paths = torch .utils .cpp_extension .include_paths ()
509413
510414 return paths
511415
0 commit comments