
Commit 43e7baa

Linter

1 parent d6ff94b commit 43e7baa

3 files changed: +196, -120 lines changed


label_studio_ml/examples/timeseries_segmenter/_wsgi.py

Lines changed: 78 additions & 48 deletions

@@ -4,30 +4,30 @@
 import logging
 import logging.config
 
-logging.config.dictConfig({
-  "version": 1,
-  "disable_existing_loggers": False,
-  "formatters": {
-    "standard": {
-      "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
+logging.config.dictConfig(
+    {
+        'version': 1,
+        'disable_existing_loggers': False,
+        'formatters': {
+            'standard': {
+                'format': '[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s'
+            }
+        },
+        'handlers': {
+            'console': {
+                'class': 'logging.StreamHandler',
+                'level': os.getenv('LOG_LEVEL'),
+                'stream': 'ext://sys.stdout',
+                'formatter': 'standard',
+            }
+        },
+        'root': {
+            'level': os.getenv('LOG_LEVEL'),
+            'handlers': ['console'],
+            'propagate': True,
+        },
     }
-  },
-  "handlers": {
-    "console": {
-      "class": "logging.StreamHandler",
-      "level": os.getenv('LOG_LEVEL'),
-      "stream": "ext://sys.stdout",
-      "formatter": "standard"
-    }
-  },
-  "root": {
-    "level": os.getenv('LOG_LEVEL'),
-    "handlers": [
-      "console"
-    ],
-    "propagate": True
-  }
-})
+)
 
 from label_studio_ml.api import init_app
 from model import TimeSeriesSegmenter

@@ -45,37 +45,61 @@ def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
     return config
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Label studio')
     parser.add_argument(
-        '-p', '--port', dest='port', type=int, default=9090,
-        help='Server port')
+        '-p', '--port', dest='port', type=int, default=9090, help='Server port'
+    )
+    parser.add_argument(
+        '--host', dest='host', type=str, default='0.0.0.0', help='Server host'
+    )
     parser.add_argument(
-        '--host', dest='host', type=str, default='0.0.0.0',
-        help='Server host')
+        '--kwargs',
+        '--with',
+        dest='kwargs',
+        metavar='KEY=VAL',
+        nargs='+',
+        type=lambda kv: kv.split('='),
+        help='Additional LabelStudioMLBase model initialization kwargs',
+    )
     parser.add_argument(
-        '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
-        help='Additional LabelStudioMLBase model initialization kwargs')
+        '-d',
+        '--debug',
+        dest='debug',
+        action='store_true',
+        help='Switch debug mode',
+    )
     parser.add_argument(
-        '-d', '--debug', dest='debug', action='store_true',
-        help='Switch debug mode')
+        '--log-level',
+        dest='log_level',
+        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
+        default=None,
+        help='Logging level',
+    )
     parser.add_argument(
-        '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
-        help='Logging level')
+        '--model-dir',
+        dest='model_dir',
+        default=os.path.dirname(__file__),
+        help='Directory where models are stored (relative to the project directory)',
+    )
     parser.add_argument(
-        '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
-        help='Directory where models are stored (relative to the project directory)')
+        '--check',
+        dest='check',
+        action='store_true',
+        help='Validate model instance before launching server',
+    )
     parser.add_argument(
-        '--check', dest='check', action='store_true',
-        help='Validate model instance before launching server')
-    parser.add_argument('--basic-auth-user',
-                        default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
-                        help='Basic auth user')
-
-    parser.add_argument('--basic-auth-pass',
-                        default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
-                        help='Basic auth pass')
-
+        '--basic-auth-user',
+        default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
+        help='Basic auth user',
+    )
+
+    parser.add_argument(
+        '--basic-auth-pass',
+        default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
+        help='Basic auth pass',
+    )
+
     args = parser.parse_args()
 
     # setup logging level

@@ -110,10 +134,16 @@ def parse_kwargs():
     kwargs.update(parse_kwargs())
 
     if args.check:
-        print('Check "' + TimeSeriesSegmenter.__name__ + '" instance creation..')
+        print(
+            'Check "' + TimeSeriesSegmenter.__name__ + '" instance creation..'
+        )
        model = TimeSeriesSegmenter(**kwargs)
 
-    app = init_app(model_class=TimeSeriesSegmenter, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)
+    app = init_app(
+        model_class=TimeSeriesSegmenter,
+        basic_auth_user=args.basic_auth_user,
+        basic_auth_pass=args.basic_auth_pass,
+    )
 
     app.run(host=args.host, port=args.port, debug=args.debug)
 
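
For context on the reformatted argument block above: the --kwargs/--with option collects KEY=VAL pairs, the argparse type callable splits each pair on '=', and the resulting pairs are later merged into the model initialization kwargs. Below is a minimal, self-contained sketch of that mechanism; the option definition is copied from the diff, while the invocation values and the plain dict fold are illustrative (the example's parse_kwargs helper may apply additional value parsing).

import argparse

parser = argparse.ArgumentParser(description='Label studio')
parser.add_argument(
    '--kwargs',
    '--with',
    dest='kwargs',
    metavar='KEY=VAL',
    nargs='+',
    type=lambda kv: kv.split('='),
    help='Additional LabelStudioMLBase model initialization kwargs',
)

# Hypothetical invocation: _wsgi.py --kwargs device=cpu threshold=0.5
args = parser.parse_args(['--kwargs', 'device=cpu', 'threshold=0.5'])

# Each argument arrives pre-split into [key, value]; fold the pairs into a
# dict, which is the shape ultimately passed to the model constructor.
kwargs = {k: v for k, v in args.kwargs}
print(kwargs)  # {'device': 'cpu', 'threshold': '0.5'}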

label_studio_ml/examples/timeseries_segmenter/model.py

Lines changed: 59 additions & 41 deletions

@@ -31,12 +31,14 @@ class TimeSeriesSegmenter(LabelStudioMLBase):
 
     LABEL_STUDIO_HOST = os.getenv('LABEL_STUDIO_HOST', 'http://localhost:8080')
     LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY')
-    START_TRAINING_EACH_N_UPDATES = int(os.getenv('START_TRAINING_EACH_N_UPDATES', 10))
+    START_TRAINING_EACH_N_UPDATES = int(
+        os.getenv('START_TRAINING_EACH_N_UPDATES', 10)
+    )
     MODEL_DIR = os.getenv('MODEL_DIR', '.')
 
     def setup(self):
         """Initialize model metadata."""
-        self.set("model_version", f"{self.__class__.__name__}-v0.0.1")
+        self.set('model_version', f'{self.__class__.__name__}-v0.0.1')
 
     # ------------------------------------------------------------------
     # Utility helpers

@@ -46,49 +48,55 @@ def _get_model(self, blank: bool = False) -> LogisticRegression:
         global _model
         if _model is not None and not blank:
             return _model
-
-        model_path = os.path.join(self.MODEL_DIR, "model.pkl")
+
+        model_path = os.path.join(self.MODEL_DIR, 'model.pkl')
         if not blank and os.path.exists(model_path):
-            with open(model_path, "rb") as f:
+            with open(model_path, 'rb') as f:
                 _model = pickle.load(f)
         else:
             _model = LogisticRegression(max_iter=1000)
         return _model
 
     def _get_labeling_params(self) -> Dict:
         """Return tag names and channel information from the labeling config."""
-        from_name, to_name, value = self.label_interface.get_first_tag_occurence(
-            "TimeSeriesLabels", "TimeSeries"
+        (
+            from_name,
+            to_name,
+            value,
+        ) = self.label_interface.get_first_tag_occurence(
+            'TimeSeriesLabels', 'TimeSeries'
         )
         tag = self.label_interface.get_tag(from_name)
         labels = list(tag.labels)
         ts_tag = self.label_interface.get_tag(to_name)
-        time_col = ts_tag.attr.get("timeColumn")
+        time_col = ts_tag.attr.get('timeColumn')
         # Parse channel names from the original XML because TimeSeries tag
         # does not expose its children via label-studio's interface
         import xml.etree.ElementTree as ET
 
         root = ET.fromstring(self.label_config)
         ts_elem = root.find(f".//TimeSeries[@name='{to_name}']")
-        channels = [ch.attrib["column"] for ch in ts_elem.findall("Channel")]
+        channels = [ch.attrib['column'] for ch in ts_elem.findall('Channel')]
 
         return {
             'from_name': from_name,
             'to_name': to_name,
             'value': value,
             'labels': labels,
             'time_col': time_col,
-            'channels': channels
+            'channels': channels,
         }
 
     def _read_csv(self, task: Dict, path: str) -> pd.DataFrame:
         """Load a CSV referenced by the task using Label Studio utilities."""
         csv_str = self.preload_task_data(task, path)
         return pd.read_csv(io.StringIO(csv_str))
 
-    def _predict_task(self, task: Dict, model: LogisticRegression, params: Dict) -> Dict:
+    def _predict_task(
+        self, task: Dict, model: LogisticRegression, params: Dict
+    ) -> Dict:
         """Return Label Studio-style prediction for a single task."""
-        df = self._read_csv(task, task["data"][params["value"]])
+        df = self._read_csv(task, task['data'][params['value']])
 
         # Vector of sensor values per row
         X = df[params['channels']].values

@@ -108,26 +116,28 @@ def _predict_task(self, task: Dict, model: LogisticRegression, params: Dict) ->
         for seg in segments:
             score = float(np.mean(seg['scores']))
             avg_score += score
-            results.append({
-                'from_name': params['from_name'],
-                'to_name': params['to_name'],
-                'type': 'timeserieslabels',
-                'value': {
-                    'start': seg['start'],
-                    'end': seg['end'],
-                    'instant': False,
-                    'timeserieslabels': [seg['label']]
-                },
-                'score': score
-            })
+            results.append(
+                {
+                    'from_name': params['from_name'],
+                    'to_name': params['to_name'],
+                    'type': 'timeserieslabels',
+                    'value': {
+                        'start': seg['start'],
+                        'end': seg['end'],
+                        'instant': False,
+                        'timeserieslabels': [seg['label']],
+                    },
+                    'score': score,
+                }
+            )
 
         if not results:
             return {}
 
         return {
             'result': results,
             'score': avg_score / len(results),
-            'model_version': self.get('model_version')
+            'model_version': self.get('model_version'),
         }
 
     def _group_rows(self, df: pd.DataFrame, time_col: str) -> List[Dict]:

@@ -146,13 +156,15 @@ def _group_rows(self, df: pd.DataFrame, time_col: str) -> List[Dict]:
                     'label': label,
                     'start': row[time_col],
                     'end': row[time_col],
-                    'scores': [row['score']]
+                    'scores': [row['score']],
                 }
         if current:
             segments.append(current)
         return segments
 
-    def _collect_samples(self, tasks: List[Dict], params: Dict, label2idx: Dict[str, int]) -> Tuple[List, List]:
+    def _collect_samples(
+        self, tasks: List[Dict], params: Dict, label2idx: Dict[str, int]
+    ) -> Tuple[List, List]:
         """Return feature matrix and label vector built from all labeled tasks."""
         X, y = [], []
         for task in tasks:

@@ -169,9 +181,8 @@ def _collect_samples(self, tasks: List[Dict], params: Dict, label2idx: Dict[str,
                 start = r['value']['start']
                 end = r['value']['end']
                 label = r['value']['timeserieslabels'][0]
-                mask = (
-                    (df[params['time_col']] >= start)
-                    & (df[params['time_col']] <= end)
+                mask = (df[params['time_col']] >= start) & (
+                    df[params['time_col']] <= end
                 )
                 seg = df.loc[mask, params['channels']].values
                 X.extend(seg)

@@ -191,9 +202,13 @@ def predict(
         """Return time series segments predicted for the given tasks."""
         params = self._get_labeling_params()
         model = self._get_model()
-        predictions = [self._predict_task(task, model, params) for task in tasks]
+        predictions = [
+            self._predict_task(task, model, params) for task in tasks
+        ]
 
-        return ModelResponse(predictions=predictions, model_version=self.get('model_version'))
+        return ModelResponse(
+            predictions=predictions, model_version=self.get('model_version')
+        )
 
     def _get_tasks(self, project_id: int) -> List[Dict]:
         """Fetch labeled tasks from Label Studio."""

@@ -206,20 +221,24 @@ def _get_tasks(self, project_id: int) -> List[Dict]:
     def fit(self, event, data, **kwargs):
         """Train the model on all labeled segments."""
         if event not in (
-            "ANNOTATION_CREATED",
-            "ANNOTATION_UPDATED",
-            "START_TRAINING",
+            'ANNOTATION_CREATED',
+            'ANNOTATION_UPDATED',
+            'START_TRAINING',
         ):
-            logger.info("Skip training: event %s is not supported", event)
+            logger.info('Skip training: event %s is not supported', event)
             return
-
+
         project_id = data['annotation']['project']
         tasks = self._get_tasks(project_id)
-        if len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0 and event != 'START_TRAINING':
+        if (
+            len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0
+            and event != 'START_TRAINING'
+        ):
             logger.info(
-                f'Skip training: {len(tasks)} tasks are not multiple of {self.START_TRAINING_EACH_N_UPDATES}')
+                f'Skip training: {len(tasks)} tasks are not multiple of {self.START_TRAINING_EACH_N_UPDATES}'
+            )
             return
-
+
         params = self._get_labeling_params()
         label2idx = {l: i for i, l in enumerate(params['labels'])}
 

@@ -234,4 +253,3 @@ def fit(self, event, data, **kwargs):
         global _model
         _model = None # reload on next predict
         self._get_model()
-
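
For context on the _group_rows fragments above: consecutive rows that share a predicted label are merged into one segment carrying a start time, an end time, and the per-row scores, which _predict_task then averages into a region score. Below is a self-contained sketch of that grouping idea using a hypothetical per-row prediction DataFrame; it mirrors the structure visible in the hunk rather than the example's exact implementation.

from typing import Dict, List

import pandas as pd


def group_rows(df: pd.DataFrame, time_col: str) -> List[Dict]:
    """Merge consecutive rows with the same 'label' into segments."""
    segments, current = [], None
    for _, row in df.iterrows():
        label = row['label']
        if current and current['label'] == label:
            # Same label as the previous row: extend the open segment.
            current['end'] = row[time_col]
            current['scores'].append(row['score'])
        else:
            # Label changed: close the open segment and start a new one.
            if current:
                segments.append(current)
            current = {
                'label': label,
                'start': row[time_col],
                'end': row[time_col],
                'scores': [row['score']],
            }
    if current:
        segments.append(current)
    return segments


# Hypothetical per-row model output: a timestamp, a predicted label, a score.
df = pd.DataFrame(
    {
        'time': [0, 1, 2, 3, 4],
        'label': ['walk', 'walk', 'run', 'run', 'walk'],
        'score': [0.9, 0.8, 0.95, 0.9, 0.7],
    }
)
print(group_rows(df, 'time'))
# [{'label': 'walk', 'start': 0, 'end': 1, 'scores': [0.9, 0.8]}, ...]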
