Skip to content

Commit c203fc2

Browse files
committed
Runner extension and cuml DF fix
1 parent e6296c3 commit c203fc2

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

cuml/df_clsf.py

100644100755
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
float_or_int, parse_args, measure_function_time, load_data, print_output,
88
accuracy_score
99
)
10+
import cuml
1011
from cuml.ensemble import RandomForestClassifier
1112

1213
parser = argparse.ArgumentParser(description='cuml random forest '
@@ -67,7 +68,10 @@ def fit(X, y):
6768

6869

6970
def predict(X):
70-
return clf.predict(X, predict_model='GPU', num_classes=params.n_classes)
71+
prediction_args = {'predict_model': 'GPU'}
72+
if int(cuml.__version__.split('.')[1]) <= 14:
73+
prediction_args.update({'num_classes': params.n_classes})
74+
return clf.predict(X, **prediction_args)
7175

7276

7377
columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',

runner.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ def filter_stderr(text):
3333
return text
3434

3535

36+
def filter_stdout(text):
37+
verbosity_letters = 'EWIDT'
38+
filtered, extra = '', ''
39+
for line in text.split('\n'):
40+
if line == '':
41+
continue
42+
to_remove = False
43+
for letter in verbosity_letters:
44+
if line.startswith(f'[{letter}]'):
45+
to_remove = True
46+
break
47+
if to_remove:
48+
extra += line + '\n'
49+
else:
50+
filtered += line + '\n'
51+
return filtered, extra
52+
53+
3654
def generate_cases(params):
3755
'''
3856
Generate cases for benchmarking by iterating of
@@ -284,12 +302,17 @@ class GenerationArgs:
284302
verbose_print(command)
285303
if not args.dummy_run:
286304
stdout, stderr = read_output_from_command(command)
305+
stdout, extra_stdout = filter_stdout(stdout)
287306
stderr = filter_stderr(stderr)
307+
if extra_stdout != '':
308+
stderr += f'CASE {case} EXTRA OUTPUT:\n' \
309+
+ f'{extra_stdout}\n'
288310
if args.output_format == 'json':
289311
try:
290312
json_result['results'].extend(json.loads(stdout))
291313
except json.JSONDecodeError as decoding_exception:
292-
stderr += str(decoding_exception) + '\n'
314+
stderr += f'CASE {case} JSON DECODING ERROR:\n' \
315+
+ f'{decoding_exception}\n{stdout}\n'
293316
elif args.output_format == 'csv':
294317
csv_result += stdout + '\n'
295318
if stderr != '':

0 commit comments

Comments
 (0)