@@ -71,11 +71,11 @@ def cmp_plot(data, input_features, metrics, labels, output_dir):
7171 print ("Ploting graphs DONE" )
7272
7373
74- @enter_decorate ("CMP EXPORT TBALE " , filename = TABLE_NAME )
74+ @enter_decorate ("CMP EXPORT TABLE " , filename = TABLE_NAME )
7575def cmp_export_table (
7676 all_clients_results : List [List [Dict ]],
7777 input_features : List [str ],
78- output_metrics : List [Dict ],
78+ output_metrics : List [str ],
7979 num_clients : int ,
8080 num_servers : int ,
8181 output_dir : str ,
@@ -84,45 +84,140 @@ def cmp_export_table(
8484 if not all_clients_results or not all_clients_results [0 ]:
8585 raise ValueError ("No data available to export." )
8686
87- if server_labels [0 ] is None :
87+ if server_labels is None or server_labels [0 ] is None :
8888 server_labels = [f"server_{ i + 1 } " for i in range (num_servers )]
8989
90- # header
91- header_cells = input_features + [" - " ]
92- for output_metric in output_metrics :
93- header_cells += [output_metric ] + [" - " ] * (len (server_labels ) - 1 )
94- header_row = "| " + " | " .join (map (str , header_cells )) + " |"
90+ # --- 1. 动态构建表头 ---
91+ # 将 input_features 组合成标题,例如: "Config (input_len / output_len / rate)"
92+ config_header_name = f"Config ({ ' / ' .join (input_features )} )"
9593
96- # sub header
97- sub_header_cells = [" - " ] * (len (input_features ) + 1 ) + server_labels * len (
98- output_metrics
99- )
100- sub_header_row = "| " + " | " .join (map (str , sub_header_cells )) + " |"
94+ header_cells = [config_header_name , "Metric" ] + server_labels
95+ if num_servers == 2 :
96+ header_cells .append ("Diff (%)" )
10197
98+ header_row = "| " + " | " .join (header_cells ) + " |"
10299 separator_row = "| " + " | " .join (["---" ] * len (header_cells )) + " |"
103- lines = [header_row , sub_header_row , separator_row ]
100+ lines = [header_row , separator_row ]
104101
102+ # --- 2. 遍历每一个配置 (Client Config) ---
105103 for client_idx in range (num_clients ):
106- #
107- row_values = []
108-
109- all_server_metrics = []
110- for server_idx in range (num_servers ):
111- server_metrics = []
112- idx = client_idx + server_idx * num_clients
113- row_results = all_clients_results [idx ]
114- if server_idx == 0 :
115- for feature in input_features :
116- row_values .append (f"{ row_results [0 ][feature ]:.2f} " )
117- row_values .append ("-" )
118- for metric in output_metrics :
119- server_metrics .append (avg_std_strf (metric , row_results , precision = 2 ))
120- all_server_metrics .append (server_metrics )
121-
122- for i in range (len (output_metrics )):
123- for j in range (num_servers ):
124- row_values .append (all_server_metrics [j ][i ])
125- lines .append ("| " + " | " .join (row_values ) + " |" )
126-
127- with open (os .path .join (output_dir , TABLE_NAME ), mode = "w" , encoding = "utf-8" ) as f :
104+
105+ # 动态提取当前配置下所有 feature 的值
106+ # 索引逻辑: client_idx 对应第一个 server 的该配置结果
107+ first_server_res_list = all_clients_results [client_idx ]
108+ first_sample = first_server_res_list [0 ]
109+
110+ config_val_list = []
111+ for feat in input_features :
112+ val = first_sample .get (feat , "N/A" )
113+ # 格式化数值:如果是浮点数保留两位,否则转字符串
114+ if isinstance (val , float ):
115+ config_val_list .append (f"{ val :.2f} " )
116+ else :
117+ config_val_list .append (str (val ))
118+
119+ # 拼接后的配置字符串,例如 "1200.00 / 800.00 / 4.00"
120+ config_str = " / " .join (config_val_list )
121+
122+ # --- 3. 遍历每一个指标 (Metric) ---
123+ for m_idx , metric in enumerate (output_metrics ):
124+ row_values = []
125+
126+ # 第一列:仅在指标块的第一行显示配置
127+ if m_idx == 0 :
128+ row_values .append (f"**{ config_str } **" )
129+ else :
130+ row_values .append (" " )
131+
132+ # 第二列:指标名称
133+ row_values .append (metric )
134+
135+ # 后面几列:各个 Server 的数值
136+ numerical_means = []
137+ for s_idx in range (num_servers ):
138+ idx = client_idx + s_idx * num_clients
139+ res_list = all_clients_results [idx ]
140+
141+ # 使用你原有的格式化函数获取 "均值 ± 标准差"
142+ display_str = avg_std_strf (metric , res_list , precision = 2 )
143+ row_values .append (display_str )
144+
145+ # 为计算 Diff 提取纯数值均值
146+ try :
147+ m_val = sum (r [metric ] for r in res_list ) / len (res_list )
148+ numerical_means .append (m_val )
149+ except :
150+ numerical_means .append (None )
151+
152+ # 最后一列:动态计算两个 Server 间的差异
153+ if num_servers == 2 :
154+ v1 , v2 = numerical_means [0 ], numerical_means [1 ]
155+ if v1 is not None and v2 is not None and v1 != 0 :
156+ diff = (v2 - v1 ) / v1 * 100
157+ row_values .append (f"{ diff :+.2f} %" )
158+ else :
159+ row_values .append ("-" )
160+
161+ lines .append ("| " + " | " .join (row_values ) + " |" )
162+
163+ # --- 4. 写入文件 ---
164+ output_path = os .path .join (output_dir , TABLE_NAME )
165+ with open (output_path , mode = "w" , encoding = "utf-8" ) as f :
128166 f .write ("\n " .join (lines ))
167+
168+
169+ # @enter_decorate("CMP EXPORT TBALE", filename=TABLE_NAME)
170+ # def cmp_export_table(
171+ # all_clients_results: List[List[Dict]],
172+ # input_features: List[str],
173+ # output_metrics: List[Dict],
174+ # num_clients: int,
175+ # num_servers: int,
176+ # output_dir: str,
177+ # server_labels: List[str],
178+ # ):
179+ # if not all_clients_results or not all_clients_results[0]:
180+ # raise ValueError("No data available to export.")
181+ #
182+ # if server_labels[0] is None:
183+ # server_labels = [f"server_{i + 1}" for i in range(num_servers)]
184+ #
185+ # # header
186+ # header_cells = input_features + [" - "]
187+ # for output_metric in output_metrics:
188+ # header_cells += [output_metric] + [" - "] * (len(server_labels) - 1)
189+ # header_row = "| " + " | ".join(map(str, header_cells)) + " |"
190+ #
191+ # # sub header
192+ # sub_header_cells = [" - "] * (len(input_features) + 1) + server_labels * len(
193+ # output_metrics
194+ # )
195+ # sub_header_row = "| " + " | ".join(map(str, sub_header_cells)) + " |"
196+ #
197+ # separator_row = "| " + " | ".join(["---"] * len(header_cells)) + " |"
198+ # lines = [header_row, sub_header_row, separator_row]
199+ #
200+ # for client_idx in range(num_clients):
201+ # #
202+ # row_values = []
203+ #
204+ # all_server_metrics = []
205+ # for server_idx in range(num_servers):
206+ # server_metrics = []
207+ # idx = client_idx + server_idx * num_clients
208+ # row_results = all_clients_results[idx]
209+ # if server_idx == 0:
210+ # for feature in input_features:
211+ # row_values.append(f"{row_results[0][feature]:.2f}")
212+ # row_values.append("-")
213+ # for metric in output_metrics:
214+ # server_metrics.append(avg_std_strf(metric, row_results, precision=2))
215+ # all_server_metrics.append(server_metrics)
216+ #
217+ # for i in range(len(output_metrics)):
218+ # for j in range(num_servers):
219+ # row_values.append(all_server_metrics[j][i])
220+ # lines.append("| " + " | ".join(row_values) + " |")
221+ #
222+ # with open(os.path.join(output_dir, TABLE_NAME), mode="w", encoding="utf-8") as f:
223+ # f.write("\n".join(lines))
0 commit comments