Skip to content

Commit 9ad9472

Browse files
wrap lshw in list, always try nvidia-smi and rocm-smi
1 parent d156ea7 commit 9ad9472

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

kernel_tuner/file_utils.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,27 +165,41 @@ def get_device_query(target):
165165
raise ValueError("get_device_query target not supported")
166166

167167

168-
def store_metadata_file(metadata_filename, target="nvidia"):
168+
def store_metadata_file(metadata_filename):
169169
""" Store the metadata about the current hardware and software environment in a JSON output file
170170
171171
This function produces a JSON file that adheres to the T4 auto-tuning metadata JSON schema.
172172
173173
:param metadata_filename: Name of the metadata file to be created
174174
:type metadata_filename: string
175175
176-
:param target: Target specifies whether to include the metadata of the 'nvidia' or 'amd' GPUs in the system
177-
:type target: string
178-
179176
"""
180177
metadata_filename = filename_ensure_json_extension(metadata_filename)
181178
metadata = {}
182179

183180
# lshw only works on Linux; this intentionally raises a FileNotFoundError when run on systems that do not have it
184181
lshw_out = subprocess.run(["lshw", "-json"], capture_output=True)
185-
metadata["hardware"] = dict(lshw=json.loads(lshw_out.stdout))
186182

187-
# only works if nvidia-smi (for NVIDIA) or rocm-smi (for AMD) is present, raises FileNotFoundError when not present
188-
device_query = get_device_query(target)
183+
# lshw sometimes outputs a list of length 1 and sometimes a bare dict; the schema requires a list
184+
lshw_string = lshw_out.stdout.decode('utf-8').strip()
185+
if lshw_string[0] == '{' and lshw_string[-1] == '}':
186+
lshw_string = '[' + lshw_string + ']'
187+
188+
metadata["hardware"] = dict(lshw=json.loads(lshw_string))
189+
190+
# attempts to use nvidia-smi or rocm-smi if present
191+
device_query = {}
192+
try:
193+
device_query['nvidia-smi'] = get_device_query("nvidia")
194+
except FileNotFoundError:
195+
# ignore if nvidia-smi is not found
196+
pass
197+
198+
try:
199+
device_query['rocm-smi'] = get_device_query("amd")
200+
except FileNotFoundError:
201+
# ignore if rocm-smi is not found
202+
pass
189203

190204
metadata["environment"] = dict(device_query=device_query,
191205
requirements=get_dependencies())

0 commit comments

Comments
 (0)