|
16 | 16 |
|
17 | 17 | import sys
|
18 | 18 | import os
|
19 |
| -from cpuinfo import get_cpu_info |
20 | 19 |
|
21 | 20 | core_suffix = 'so'
|
22 | 21 | if os.name == 'nt':
|
|
57 | 56 | except Exception as e:
|
58 | 57 | raise e
|
59 | 58 |
|
60 |
| -load_noavx = False |
61 | 59 |
|
62 |
| -has_avx = False |
63 |
| -if sys.platform == 'darwin': |
64 |
| - try: |
65 |
| - has_avx = os.popen('sysctl machdep.cpu.features | grep -i avx').read( |
66 |
| - ) != '' |
67 |
| - except Exception as e: |
68 |
| - sys.stderr.write( |
69 |
| - 'Can not get the AVX flag from machdep.cpu.features.\n') |
70 |
| - if not has_avx: |
def avx_supported():
    """
    Whether current system(Linux, MacOS, Windows) is supported with AVX.

    The probe is best effort and never raises for a failed detection: any
    error is written to stderr and treated as "AVX not available".

    Returns:
        bool: True when an AVX flag is found in /proc/cpuinfo (Linux), in
        ``sysctl machdep.cpu.features`` or ``machdep.cpu.leaf7_features``
        (MacOS), or in bit 28 of ECX returned by CPUID with EAX=1 (Windows).
        False otherwise, including on any other platform.
    """
    import platform
    from .. import compat as cpt

    def report_failure(message, error):
        # Log and carry on; the caller falls back to the non-AVX core.
        sys.stderr.write('%s\nThe original error is: %s\n' %
                         (message, cpt.get_exception_message(error)))

    def sysctl_mentions_avx(key):
        # True when `sysctl <key>` prints at least one line containing "avx".
        try:
            # Close the pipe explicitly instead of leaving it to the GC.
            with os.popen('sysctl %s | grep -i avx' % key) as pipe:
                return pipe.read() != ''
        except Exception as e:
            report_failure('Can not get the AVX flag from %s.' % key, e)
            return False

    sysstr = platform.system().lower()

    if sysstr == 'linux':
        try:
            # Read the file directly rather than spawning `cat | grep`; bytes
            # mode keeps the check independent of the locale encoding.
            with open('/proc/cpuinfo', 'rb') as cpuinfo:
                return b'avx' in cpuinfo.read().lower()
        except Exception as e:
            report_failure('Can not get the AVX flag from /proc/cpuinfo.', e)
            return False

    if sysstr == 'darwin':
        # leaf7_features is only consulted when features reports no AVX flag.
        return (sysctl_mentions_avx('machdep.cpu.features') or
                sysctl_mentions_avx('machdep.cpu.leaf7_features'))

    if sysstr == 'windows':
        import ctypes
        kernel32 = ctypes.windll.kernel32
        ONE_PAGE = ctypes.c_size_t(0x1000)
        MEM_COMMIT = ctypes.c_ulong(0x1000)
        MEM_RELEASE = ctypes.c_ulong(0x8000)
        PAGE_READWRITE = ctypes.c_ulong(0x4)
        PAGE_EXECUTE = ctypes.c_ulong(0x10)

        def free_page(address):
            # MEM_RELEASE requires dwSize == 0 and frees the whole allocation.
            kernel32.VirtualFree(
                ctypes.c_void_p(address), ctypes.c_size_t(0), MEM_RELEASE)

        def asm_func(code_str, restype=ctypes.c_uint32, argtypes=()):
            # Call the code_str as a function
            # Alloc 1 page to ensure the protection
            pfnVirtualAlloc = kernel32.VirtualAlloc
            pfnVirtualAlloc.restype = ctypes.c_void_p
            address = pfnVirtualAlloc(None, ONE_PAGE, MEM_COMMIT,
                                      PAGE_READWRITE)
            if not address:
                raise Exception("Failed to VirtualAlloc")

            try:
                # Copy the code into the memory segment
                ctypes.memmove(address, code_str, len(code_str))

                # Enable execute permissions
                old_protect = ctypes.c_ulong(0)
                res = kernel32.VirtualProtect(
                    ctypes.c_void_p(address), ONE_PAGE, PAGE_EXECUTE,
                    ctypes.byref(old_protect))
                if not res:
                    raise Exception("Failed VirtualProtect")

                # Flush instruction cache
                pfnGetCurrentProcess = kernel32.GetCurrentProcess
                pfnGetCurrentProcess.restype = ctypes.c_void_p
                prochandle = ctypes.c_void_p(pfnGetCurrentProcess())
                res = kernel32.FlushInstructionCache(
                    prochandle, ctypes.c_void_p(address), ONE_PAGE)
                if not res:
                    raise Exception("Failed FlushInstructionCache")

                # Cast the memory to function
                functype = ctypes.CFUNCTYPE(restype, *argtypes)
                return functype(address), address
            except Exception:
                # Do not leak the page when any later step fails.
                free_page(address)
                raise

        # http://en.wikipedia.org/wiki/CPUID#EAX.3D1:_Processor_Info_and_Feature_Bits
        # push rbx; mov eax,0x1; cpuid; mov eax,ecx; pop rbx; ret
        # cpuid overwrites EBX/RBX, which is callee-saved on Windows, so the
        # stub saves and restores it. The single-byte push/pop encodings are
        # the same in 32-bit and 64-bit mode.
        code_str = b"\x53\xB8\x01\x00\x00\x00\x0F\xA2\x89\xC8\x5B\xC3"
        avx_bit = 28
        retval = 0
        try:
            # Convert the code_str into a function that returns uint
            func, address = asm_func(code_str)
            try:
                retval = func()
            finally:
                free_page(address)
        except Exception as e:
            report_failure('Failed getting the AVX flag on Windows.', e)
        return (retval & (1 << avx_bit)) > 0

    sys.stderr.write('Do not get AVX flag on %s\n' % sysstr)
    return False
| 160 | + |
| 161 | + |
| 162 | +load_noavx = False |
| 163 | + |
| 164 | +if avx_supported(): |
81 | 165 | try:
|
82 | 166 | from .core_avx import *
|
83 | 167 | from .core_avx import __doc__, __file__, __name__, __package__
|
|
91 | 175 | from .core_avx import _set_fuse_parameter_memory_size
|
92 | 176 | from .core_avx import _is_dygraph_debug_enabled
|
93 | 177 | from .core_avx import _dygraph_debug_level
|
94 |
| - except ImportError as e: |
| 178 | + except Exception as e: |
95 | 179 | if has_avx_core:
|
96 | 180 | raise e
|
97 | 181 | else:
|
| 182 | + from .. import compat as cpt |
98 | 183 | sys.stderr.write(
|
99 | 184 | 'WARNING: Do not have avx core. You may not build with AVX, '
|
100 | 185 | 'but AVX is supported on local machine.\n You could build paddle '
|
101 |
| - 'WITH_AVX=ON to get better performance.\n') |
| 186 | + 'WITH_AVX=ON to get better performance.\n' |
| 187 | + 'The original error is: %s\n' % cpt.get_exception_message(e)) |
102 | 188 | load_noavx = True
|
103 |
| - except Exception as e: |
104 |
| - raise e |
105 | 189 | else:
|
106 | 190 | load_noavx = True
|
107 | 191 |
|
|
119 | 203 | from .core_noavx import _set_fuse_parameter_memory_size
|
120 | 204 | from .core_noavx import _is_dygraph_debug_enabled
|
121 | 205 | from .core_noavx import _dygraph_debug_level
|
122 |
| - except ImportError as e: |
| 206 | + except Exception as e: |
123 | 207 | if has_noavx_core:
|
124 | 208 | sys.stderr.write(
|
125 | 209 | 'Error: Can not import noavx core while this file exists ' +
|
126 | 210 | current_path + os.sep + 'core_noavx.' + core_suffix + '\n')
|
127 | 211 | raise e
|
128 |
| - except Exception as e: |
129 |
| - raise e |
|
0 commit comments