Skip to content

Commit 61dbd99

Browse files
authored
Merge pull request #63 from MoseleyBioinformaticsLab/granular
improves naming and updates documentation accordingly
2 parents 92aae77 + a953a67 commit 61dbd99

File tree

8 files changed

+34
-33
lines changed

8 files changed

+34
-33
lines changed

src/gpu_tracker/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
--guuids=<gpu-uuids> Comma-separated list of the UUIDs of the GPUs for which to track utilization, e.g. gpu-uuid1,gpu-uuid2,etc. Defaults to all the GPUs in the system.
2121
--disable-logs If set, warnings are suppressed during tracking. Otherwise, the Tracker logs warnings as usual.
2222
--gb=<gpu-brand> The brand of GPU to profile. Valid values are nvidia and amd. Defaults to the brand of GPU detected in the system, checking NVIDIA first.
23-
--tf=<tracking-file> If specified, stores the individual resource usage measurements at each iteration. Valid file formats are CSV (.csv) and SQLite (.sqlite) where the SQLite file format stores the data in a table called "tracking" and allows for more efficient querying.
23+
--tf=<tracking-file> If specified, stores the individual resource usage measurements at each iteration. Valid file formats are CSV (.csv) and SQLite (.sqlite) where the SQLite file format stores the data in a table called "data" and allows for more efficient querying.
2424
"""
2525
import docopt as doc
2626
import subprocess as subp

src/gpu_tracker/_helper_classes.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,20 +141,20 @@ class _TimepointUsage:
141141
class _SubTrackerLog:
142142
class CodeBlockPosition(enum.Enum):
143143
START = 'START'
144-
END = 'END'
144+
STOP = 'STOP'
145145
code_block_name: str
146146
position: CodeBlockPosition
147147
timestamp: float
148148

149149

150-
class _TrackingFile(abc.ABC):
150+
class _Writer(abc.ABC):
151151
@staticmethod
152-
def create(file: str | None) -> _TrackingFile | None:
152+
def create(file: str | None) -> _Writer | None:
153153
if file is not None:
154154
if file.endswith('.csv'):
155-
return _CSVTrackingFile(file)
155+
return _CSVWriter(file)
156156
elif file.endswith('.sqlite'):
157-
return _SQLiteTrackingFile(file)
157+
return _SQLiteWriter(file)
158158
else:
159159
raise ValueError(
160160
f'Invalid file name: "{file}". Valid file extensions are ".csv" and ".sqlite".')
@@ -164,7 +164,7 @@ def create(file: str | None) -> _TrackingFile | None:
164164
def __init__(self, file: str):
165165
self._file = file
166166

167-
def write_row(self, values: _TimepointUsage | _SubTrackerLog):
167+
def write_row(self, values: object):
168168
values = dclass.asdict(values)
169169
if not os.path.isfile(self._file):
170170
self._create_file(values)
@@ -179,7 +179,7 @@ def _create_file(self, values: dict):
179179
pass # pragma: nocover
180180

181181

182-
class _CSVTrackingFile(_TrackingFile):
182+
class _CSVWriter(_Writer):
183183
def _write_row(self, values: dict):
184184
with open(self._file, 'a', newline='') as f:
185185
writer = csv.DictWriter(f, fieldnames=values.keys())
@@ -191,13 +191,14 @@ def _create_file(self, values: dict):
191191
writer.writeheader()
192192

193193

194-
class _SQLiteTrackingFile(_TrackingFile):
195-
_SQLITE_TABLE_NAME = 'tracking'
194+
class _SQLiteWriter(_Writer):
195+
_DATA_TABLE = 'data'
196+
_STATIC_DATA_TABLE = 'static_data'
196197

197198
def _write_row(self, values: dict):
198199
engine = sqlalc.create_engine(f'sqlite:///{self._file}', poolclass=sqlalc.pool.NullPool)
199200
metadata = sqlalc.MetaData()
200-
tracking_table = sqlalc.Table(_SQLiteTrackingFile._SQLITE_TABLE_NAME, metadata, autoload_with=engine)
201+
tracking_table = sqlalc.Table(_SQLiteWriter._DATA_TABLE, metadata, autoload_with=engine)
201202
Session = sqlorm.sessionmaker(bind=engine)
202203
with Session() as session:
203204
insert_stmt = sqlalc.insert(tracking_table).values(**values)
@@ -217,5 +218,5 @@ def _create_file(self, values: dict):
217218
for column_name, data_type in schema.items():
218219
sqlalchemy_type = type_mapping[data_type]
219220
columns.append(sqlalc.Column(column_name, sqlalchemy_type))
220-
sqlalc.Table(_SQLiteTrackingFile._SQLITE_TABLE_NAME, metadata, *columns)
221+
sqlalc.Table(_SQLiteWriter._DATA_TABLE, metadata, *columns)
221222
metadata.create_all(engine)

src/gpu_tracker/sub_tracker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import time
55
import functools
6-
from ._helper_classes import _TrackingFile, _SubTrackerLog
6+
from ._helper_classes import _Writer, _SubTrackerLog
77

88

99
class SubTracker:
@@ -21,7 +21,7 @@ def __init__(
2121
"""
2222
:param code_block_name: The name of the code block within a ``Tracker`` context that is being sub-tracked. Defaults to the file path followed by a colon followed by the ``code_block_attribute``.
2323
:param code_block_attribute: Only used if ``code_block_name`` is ``None``. Defaults to the line number where the SubTracker context is started.
24-
:param sub_tracking_file: The path to the file to log the time stamps of the code block being sub-tracked Defaults to the ID of the process where the SubTracker context is created and in CSV format.
24+
:param sub_tracking_file: The path to the file to log the time stamps of the code block being sub-tracked. Defaults to a CSV file named after the ID of the process where the SubTracker context is created.
2525
"""
2626
if code_block_name is not None:
2727
self.code_block_name = code_block_name
@@ -34,7 +34,7 @@ def __init__(
3434
if sub_tracking_file is None:
3535
sub_tracking_file = f'{os.getpid()}.csv'
3636
self.sub_tracking_file = sub_tracking_file
37-
self._sub_tracking_file = _TrackingFile.create(self.sub_tracking_file)
37+
self._sub_tracking_file = _Writer.create(self.sub_tracking_file)
3838

3939
def _log(self, code_block_position: _SubTrackerLog.CodeBlockPosition):
4040
sub_tracker_log = _SubTrackerLog(
@@ -46,7 +46,7 @@ def __enter__(self):
4646
return self
4747

4848
def __exit__(self, *_):
49-
self._log(_SubTrackerLog.CodeBlockPosition.END)
49+
self._log(_SubTrackerLog.CodeBlockPosition.STOP)
5050

5151

5252
def sub_track(code_block_name: str | None = None, code_block_attribute: str | None = None, sub_tracking_file: str | None = None):

src/gpu_tracker/tracker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pickle as pkl
1414
import uuid
1515
import pandas as pd
16-
from ._helper_classes import _NvidiaQuerier, _AMDQuerier, _TrackingFile, _TimepointUsage
16+
from ._helper_classes import _NvidiaQuerier, _AMDQuerier, _Writer, _TimepointUsage
1717

1818

1919
class _TrackingProcess(mproc.Process):
@@ -64,7 +64,7 @@ def __init__(
6464
self._is_linux = platform.system().lower() == 'linux'
6565
cannot_connect_warning = ('The {} command is installed but cannot connect to a GPU. '
6666
'The GPU RAM and GPU utilization values will remain 0.0.')
67-
self.tracking_file = _TrackingFile.create(tracking_file)
67+
self.tracking_file = _Writer.create(tracking_file)
6868
if gpu_brand is None:
6969
nvidia_available = _NvidiaQuerier.is_available()
7070
nvidia_installed = nvidia_available is not None
@@ -349,7 +349,7 @@ def __init__(
349349
:param n_join_attempts: The number of times the tracker attempts to join its underlying sub-process.
350350
:param join_timeout: The amount of time the tracker waits for its underlying sub-process to join.
351351
:param gpu_brand: The brand of GPU to profile. Valid values are "nvidia" and "amd". Defaults to the brand of GPU detected in the system, checking NVIDIA first.
352-
:param tracking_file: If specified, stores the individual resource usage measurements at each iteration. Valid file formats are CSV (.csv) and SQLite (.sqlite) where the SQLite file format stores the data in a table called "tracking" and allows for more efficient querying.
352+
:param tracking_file: If specified, stores the individual resource usage measurements at each iteration. Valid file formats are CSV (.csv) and SQLite (.sqlite) where the SQLite file format stores the data in a table called "data" and allows for more efficient querying.
353353
:raises ValueError: Raised if invalid arguments are provided.
354354
"""
355355
current_process_id = os.getpid()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
position,timestamp
22
START,12
3-
END,13
3+
STOP,13

tests/data/decorated-function.csv

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
position,timestamp
22
START,0
3-
END,1
3+
STOP,1
44
START,2
5-
END,3
5+
STOP,3
66
START,4
7-
END,5
7+
STOP,5
88
START,6
9-
END,7
9+
STOP,7
1010
START,8
11-
END,9
11+
STOP,9
1212
START,10
13-
END,11
13+
STOP,11

tests/data/sub-tracker.csv

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
position,timestamp
22
START,0
3-
END,1
3+
STOP,1
44
START,2
5-
END,3
5+
STOP,3
66
START,4
7-
END,5
7+
STOP,5
88
START,6
9-
END,7
9+
STOP,7
1010
START,8
11-
END,9
11+
STOP,9

tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sqlalchemy as sqlalc
33
import os
44
# noinspection PyProtectedMember
5-
from gpu_tracker._helper_classes import _SQLiteTrackingFile
5+
from gpu_tracker._helper_classes import _SQLiteWriter
66
import gpu_tracker as gput
77

88

@@ -17,7 +17,7 @@ def test_tracking_file(
1717
actual_tracking_log = pd.read_csv(actual_tracking_file)
1818
else:
1919
engine = sqlalc.create_engine(f'sqlite:///{actual_tracking_file}', poolclass=sqlalc.pool.NullPool)
20-
actual_tracking_log = pd.read_sql_table(_SQLiteTrackingFile._SQLITE_TABLE_NAME, engine)
20+
actual_tracking_log = pd.read_sql_table(_SQLiteWriter._DATA_TABLE, engine)
2121
if excluded_col is not None:
2222
actual_tracking_log[excluded_col].apply(excluded_col_test)
2323
actual_tracking_log = actual_tracking_log[actual_tracking_log.columns.difference([excluded_col])]

0 commit comments

Comments
 (0)