Skip to content

Commit c6187be

Browse files
authored
Fix/branch (#1326)
* Batch upload for scalar metrics in uploader. Added batching to the upload_scalar_metrics function to split uploads into chunks of a configurable size (default 3000). This improves reliability and performance when uploading large numbers of scalar metrics. * Refactor metric upload batching and type hints. Introduced a MetricDict TypedDict for clearer type hints and refactored the batching logic for metric, scalar, and column uploads into trace_metrics. Batching is now handled centrally via the per_request_len parameter, simplifying upload_scalar_metrics and upload_columns. Also added assertions for project and experiment IDs and improved error handling for column uploads.
1 parent b558071 commit c6187be

File tree

1 file changed

+71
-31
lines changed

1 file changed

+71
-31
lines changed

swanlab/core_python/uploader/upload.py

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
@description: 定义上传函数
66
"""
77

8-
from typing import List, Union, Literal
8+
from typing import List, Union, Literal, TypedDict
99

1010
from swanlab.log import swanlog
1111
from .model import ColumnModel, MediaModel, ScalarModel, FileModel, LogModel
@@ -15,15 +15,25 @@
1515
house_url = '/house/metrics'
1616

1717

18-
def create_data(metrics: List[dict], metrics_type: str) -> dict:
18+
# 上传指标数据
19+
class MetricDict(TypedDict):
20+
projectId: str
21+
experimentId: str
22+
type: str
23+
metrics: List[dict]
24+
flagId: Union[str, None]
25+
26+
27+
def create_data(metrics: List[dict], metrics_type: str) -> MetricDict:
1928
"""
2029
携带上传日志的指标信息
2130
"""
2231
client = get_client()
2332
# Move 等实验需要将数据上传到根实验上
2433
exp_id = client.exp.root_exp_cuid or client.exp.cuid
2534
proj_id = client.exp.root_proj_cuid or client.proj.cuid
26-
35+
assert proj_id is not None, "Project ID is empty."
36+
assert exp_id is not None, "Experiment ID is empty."
2737
flag_id = client.exp.flag_id
2838
return {
2939
"projectId": proj_id,
@@ -34,19 +44,50 @@ def create_data(metrics: List[dict], metrics_type: str) -> dict:
3444
}
3545

3646

37-
def trace_metrics(url: str, data: Union[dict, list] = None, method: Literal['post', 'put'] = 'post'):
47+
def trace_metrics(
48+
url: str,
49+
data: Union[MetricDict, list] = None,
50+
method: Literal['post', 'put'] = 'post',
51+
per_request_len: int = 5000,
52+
):
3853
"""
3954
创建指标数据方法,如果 client 处于挂起状态,则不进行上传
4055
:param url: 上传的URL地址
4156
:param data: 上传的数据,可以是字典或列表
4257
:param method: 请求方法,默认为 'post'
58+
:param per_request_len: 每次请求的最大数据长度,如果设置为-1则不进行分批上传
4359
"""
60+
# TODO 用装饰器设置client的pending状态
4461
client = get_client()
4562
if client.pending:
4663
return
47-
_, resp = getattr(client, method)(url, data)
48-
if resp.status_code == 202:
49-
client.pending = True
64+
if per_request_len == -1:
65+
_, resp = getattr(client, method)(url, data)
66+
if resp.status_code == 202:
67+
client.pending = True
68+
return
69+
return
70+
# 分批上传
71+
if isinstance(data, dict):
72+
# 1. 指标数据
73+
for i in range(0, len(data['metrics']), per_request_len):
74+
_, resp = getattr(client, method)(
75+
url,
76+
{
77+
**data,
78+
"metrics": data['metrics'][i : i + per_request_len],
79+
},
80+
)
81+
if resp.status_code == 202:
82+
client.pending = True
83+
return
84+
else:
85+
# 2. 列表数据(列等)
86+
for i in range(0, len(data), per_request_len):
87+
_, resp = getattr(client, method)(url, data[i : i + per_request_len])
88+
if resp.status_code == 202:
89+
client.pending = True
90+
return
5091

5192

5293
@sync_error_handler
@@ -106,39 +147,38 @@ def upload_files(files: List[FileModel]):
106147
if file_model.empty:
107148
return
108149
data = file_model.to_dict()
109-
trace_metrics(f'/project/{http.groupname}/{http.projname}/runs/{http.exp_id}/profile', data, method="put")
150+
trace_metrics(
151+
f'/project/{http.groupname}/{http.projname}/runs/{http.exp_id}/profile',
152+
data,
153+
method="put",
154+
per_request_len=-1,
155+
)
110156
return
111157

112158

113159
@sync_error_handler
114-
def upload_columns(columns: List[ColumnModel], per_request_len: int = 3000):
160+
def upload_columns(columns: List[ColumnModel]):
115161
"""
116162
批量上传并创建 columns,每个请求的列长度有一个最大值
117163
"""
118164
http = get_client()
119165
url = f'/experiment/{http.exp_id}/columns'
120-
# 将columns拆分成多个小的列表,每个列表的长度不能超过单个请求的最大长度
121-
columns_list = []
122-
columns_count = len(columns)
123-
for i in range(0, columns_count, per_request_len):
124-
columns_list.append([columns[i + j].to_dict() for j in range(min(per_request_len, columns_count - i))])
125-
# 上传每个列表
126-
for columns in columns_list:
127-
# 如果列表长度为0,则跳过
128-
if len(columns) == 0:
129-
continue
130-
try:
131-
trace_metrics(url, columns)
132-
except ApiError as e:
133-
# 处理实验不存在的异常
134-
if e.resp.status_code == 404:
135-
resp = decode_response(e.resp)
136-
# 实验不存在,那么剩下的列也没有必要上传了,直接返回
137-
if isinstance(resp, dict) and resp.get('code') == 'Disabled_Resource':
138-
swanlog.warning(f"Experiment {http.exp_id} has been deleted, skipping column upload.")
139-
return
140-
raise e
141-
return
166+
# 如果列表长度为0,则跳过
167+
if len(columns) == 0:
168+
swanlog.debug("No columns to upload.")
169+
return
170+
# 分批上传
171+
try:
172+
trace_metrics(url, [x.to_dict() for x in columns], per_request_len=3000)
173+
except ApiError as e:
174+
# 处理实验不存在的异常
175+
if e.resp.status_code == 404:
176+
resp = decode_response(e.resp)
177+
# 实验不存在,那么剩下的列也没有必要上传了,直接返回
178+
if isinstance(resp, dict) and resp.get('code') == 'Disabled_Resource':
179+
swanlog.warning(f"Experiment {http.exp_id} has been deleted, skipping column upload.")
180+
return
181+
raise e
142182

143183

144184
__all__ = [

0 commit comments

Comments
 (0)