diff --git a/plot.py b/plot.py index 796c160..cab7c57 100644 --- a/plot.py +++ b/plot.py @@ -66,7 +66,7 @@ def parse_input_size(name): splits = name.split('/') if len(splits) == 1: return 1 - return int(splits[1]) + return int(splits[-1]) def read_data(args): @@ -77,7 +77,7 @@ def read_data(args): msg = 'Could not parse the benchmark data. Did you forget "--benchmark_format=csv"?' logging.error(msg) exit(1) - data['label'] = data['name'].apply(lambda x: x.split('/')[0]) + data['label'] = data['name'].apply(lambda x: x.split('/')[-2]) data['input'] = data['name'].apply(parse_input_size) data[args.metric] = data[args.metric].apply(TRANSFORMS[args.transform]) return data