55@description: 定义上传函数
66"""
77
8- from typing import List , Union , Literal
8+ from typing import List , Union , Literal , TypedDict
99
1010from swanlab .log import swanlog
1111from .model import ColumnModel , MediaModel , ScalarModel , FileModel , LogModel
1515house_url = '/house/metrics'
1616
1717
18- def create_data (metrics : List [dict ], metrics_type : str ) -> dict :
# Payload shape used when uploading metric data.
class MetricDict(TypedDict):
    """Typed shape of the dictionary returned by ``create_data``.

    ``trace_metrics`` accepts this shape and, when batching, slices the
    ``metrics`` list into chunks while copying every other key unchanged.
    """

    # Project identifier; create_data asserts it is not None.
    projectId: str
    # Experiment identifier; create_data asserts it is not None.
    experimentId: str
    # Metric category -- presumably the ``metrics_type`` argument of
    # create_data; TODO confirm against the elided part of that function.
    type: str
    # The metric records to upload, one dict per record.
    metrics: List[dict]
    # Optional flag identifier; may be None.
    flagId: Union[str, None]
25+
26+
27+ def create_data (metrics : List [dict ], metrics_type : str ) -> MetricDict :
1928 """
2029 携带上传日志的指标信息
2130 """
2231 client = get_client ()
2332 # Move 等实验需要将数据上传到根实验上
2433 exp_id = client .exp .root_exp_cuid or client .exp .cuid
2534 proj_id = client .exp .root_proj_cuid or client .proj .cuid
26-
35+ assert proj_id is not None , "Project ID is empty."
36+ assert exp_id is not None , "Experiment ID is empty."
2737 flag_id = client .exp .flag_id
2838 return {
2939 "projectId" : proj_id ,
@@ -34,19 +44,50 @@ def create_data(metrics: List[dict], metrics_type: str) -> dict:
3444 }
3545
3646
def trace_metrics(
    url: str,
    data: Union[MetricDict, list, None] = None,
    method: Literal['post', 'put'] = 'post',
    per_request_len: int = 5000,
):
    """
    Upload metric data, optionally split into several requests.

    Nothing is sent while the client is in the pending state. As soon as a
    response with status code 202 is received the client is marked as pending
    and the remaining batches are skipped.

    :param url: URL the data is uploaded to
    :param data: payload to upload; either a dict carrying a ``metrics`` list
        or a plain list. ``None`` and dicts without a ``metrics`` key are sent
        as a single request.
    :param method: name of the client method used for the request, 'post' by default
    :param per_request_len: maximum number of items per request; -1 (or any
        other non-positive value) disables batching and sends ``data`` as is
    """
    # TODO: set the client's pending state with a decorator
    client = get_client()
    if client.pending:
        return
    request = getattr(client, method)

    def _send(payload) -> bool:
        """Send one request; return False once the server answers with 202."""
        _, resp = request(url, payload)
        if resp.status_code == 202:
            client.pending = True
            return False
        return True

    # A non-positive batch size cannot drive range(): 0 raises ValueError and
    # a negative step yields no batches at all, so treat both as "no batching".
    # None has no length to slice, so it is forwarded unchanged as well.
    if per_request_len <= 0 or data is None:
        _send(data)
        return

    if isinstance(data, dict):
        if 'metrics' not in data:
            # Nothing to slice: forward the dictionary in one request.
            _send(data)
            return
        # 1. Metric data: slice only the "metrics" list, keep the other keys.
        metrics = data['metrics']
        batches = (
            {**data, "metrics": metrics[start : start + per_request_len]}
            for start in range(0, len(metrics), per_request_len)
        )
    else:
        # 2. List data (columns etc.): slice the list itself.
        batches = (
            data[start : start + per_request_len]
            for start in range(0, len(data), per_request_len)
        )

    for payload in batches:
        if not _send(payload):
            return
5091
5192
5293@sync_error_handler
@@ -106,39 +147,38 @@ def upload_files(files: List[FileModel]):
106147 if file_model .empty :
107148 return
108149 data = file_model .to_dict ()
109- trace_metrics (f'/project/{ http .groupname } /{ http .projname } /runs/{ http .exp_id } /profile' , data , method = "put" )
150+ trace_metrics (
151+ f'/project/{ http .groupname } /{ http .projname } /runs/{ http .exp_id } /profile' ,
152+ data ,
153+ method = "put" ,
154+ per_request_len = - 1 ,
155+ )
110156 return
111157
112158
@sync_error_handler
def upload_columns(columns: List[ColumnModel]):
    """
    Upload and create columns in batches.

    The columns are converted with ``to_dict()`` and handed to
    ``trace_metrics``, which splits them into requests of at most 3000 columns.

    If the API answers 404 with the code ``Disabled_Resource``, a warning that
    the experiment has been deleted is logged and the upload is abandoned.
    Any other ``ApiError`` is re-raised.

    :param columns: column models to upload; an empty list is a no-op
    """
    http = get_client()
    # Nothing to upload: skip the request entirely.
    if not columns:
        swanlog.debug("No columns to upload.")
        return
    url = f'/experiment/{http.exp_id}/columns'
    # Upload in batches.
    try:
        trace_metrics(url, [column.to_dict() for column in columns], per_request_len=3000)
    except ApiError as e:
        # Handle the "experiment does not exist" case.
        if e.resp.status_code == 404:
            resp = decode_response(e.resp)
            # The experiment is gone, so the remaining columns need not be uploaded.
            if isinstance(resp, dict) and resp.get('code') == 'Disabled_Resource':
                swanlog.warning(f"Experiment {http.exp_id} has been deleted, skipping column upload.")
                return
        # Bare raise re-raises the active exception unchanged.
        raise
142182
143183
144184__all__ = [
0 commit comments