1+ from copy import deepcopy
12from typing import Any
23
34import numpy as np
@@ -18,7 +19,7 @@ def start_run(self, **kwargs: dict[str, Any]) -> None:
1819 self .history .append (("start_run" , kwargs ))
1920
2021 def start_module (self , ** kwargs : dict [str , Any ]) -> None :
21- self .history .append (("start_module" , kwargs ))
22+ self .history .append (("start_module" , deepcopy ( kwargs ) ))
2223
2324 def log_value (self , ** kwargs : dict [str , Any ]) -> None :
2425 self .history .append (("log_value" , kwargs ))
@@ -52,7 +53,7 @@ def test_pipeline_callbacks(dataset):
5253 {
5354 "module_name" : "retrieval" ,
5455 "k" : [5 , 10 ],
55- "embedder_name " : ["sergeyzh/rubert-tiny-turbo" ],
56+ "embedder_config " : ["sergeyzh/rubert-tiny-turbo" ],
5657 }
5758 ],
5859 },
@@ -98,88 +99,39 @@ def test_pipeline_callbacks(dataset):
9899 (
99100 "start_module" ,
100101 {
102+ "module_kwargs" : {"embedder_config" : "sergeyzh/rubert-tiny-turbo" , "k" : 5 },
101103 "module_name" : "retrieval" ,
102104 "num" : 0 ,
103- "module_kwargs" : {"k" : 5 , "embedder_name" : "sergeyzh/rubert-tiny-turbo" },
104- },
105- ),
106- (
107- "log_metric" ,
108- {
109- "metrics" : {
110- "retrieval_hit_rate" : 1.0 ,
111- }
112105 },
113106 ),
107+ ("log_metric" , {"metrics" : {"retrieval_hit_rate" : 1.0 }}),
114108 ("end_module" , {}),
115109 (
116110 "start_module" ,
117111 {
112+ "module_kwargs" : {"embedder_config" : "sergeyzh/rubert-tiny-turbo" , "k" : 10 },
118113 "module_name" : "retrieval" ,
119114 "num" : 1 ,
120- "module_kwargs" : {"k" : 10 , "embedder_name" : "sergeyzh/rubert-tiny-turbo" },
121- },
122- ),
123- (
124- "log_metric" ,
125- {
126- "metrics" : {
127- "retrieval_hit_rate" : 1.0 ,
128- }
129115 },
130116 ),
117+ ("log_metric" , {"metrics" : {"retrieval_hit_rate" : 1.0 }}),
131118 ("end_module" , {}),
132119 (
133120 "start_module" ,
134- {
135- "module_name" : "knn" ,
136- "num" : 0 ,
137- "module_kwargs" : {"k" : 1 , "weights" : "uniform" , "embedder_name" : "sergeyzh/rubert-tiny-turbo" },
138- },
139- ),
140- (
141- "log_metric" ,
142- {
143- "metrics" : {
144- "scoring_accuracy" : 1.0 ,
145- "scoring_roc_auc" : 1.0 ,
146- }
147- },
121+ {"module_kwargs" : {"embedder_config" : None , "k" : 1 , "weights" : "uniform" }, "module_name" : "knn" , "num" : 0 },
148122 ),
123+ ("log_metric" , {"metrics" : {"scoring_accuracy" : 1.0 , "scoring_roc_auc" : 1.0 }}),
149124 ("end_module" , {}),
150125 (
151126 "start_module" ,
152- {
153- "module_name" : "knn" ,
154- "num" : 1 ,
155- "module_kwargs" : {"k" : 1 , "weights" : "distance" , "embedder_name" : "sergeyzh/rubert-tiny-turbo" },
156- },
157- ),
158- (
159- "log_metric" ,
160- {
161- "metrics" : {
162- "scoring_accuracy" : 1.0 ,
163- "scoring_roc_auc" : 1.0 ,
164- }
165- },
127+ {"module_kwargs" : {"embedder_config" : None , "k" : 1 , "weights" : "distance" }, "module_name" : "knn" , "num" : 1 },
166128 ),
129+ ("log_metric" , {"metrics" : {"scoring_accuracy" : 1.0 , "scoring_roc_auc" : 1.0 }}),
167130 ("end_module" , {}),
168- (
169- "start_module" ,
170- {"module_name" : "linear" , "num" : 0 , "module_kwargs" : {"embedder_name" : "sergeyzh/rubert-tiny-turbo" }},
171- ),
172- (
173- "log_metric" ,
174- {
175- "metrics" : {
176- "scoring_accuracy" : 0.75 ,
177- "scoring_roc_auc" : 1.0 ,
178- }
179- },
180- ),
131+ ("start_module" , {"module_kwargs" : {"embedder_config" : None }, "module_name" : "linear" , "num" : 0 }),
132+ ("log_metric" , {"metrics" : {"scoring_accuracy" : 0.75 , "scoring_roc_auc" : 1.0 }}),
181133 ("end_module" , {}),
182- ("start_module" , {"module_name " : "threshold" , "num" : 0 , "module_kwargs " : { "thresh" : 0.5 } }),
134+ ("start_module" , {"module_kwargs " : { "thresh" : 0.5 } , "module_name " : "threshold" , "num" : 0 }),
183135 (
184136 "log_metric" ,
185137 {
@@ -193,7 +145,7 @@ def test_pipeline_callbacks(dataset):
193145 },
194146 ),
195147 ("end_module" , {}),
196- ("start_module" , {"module_name " : "argmax" , "num " : 0 , "module_kwargs " : {} }),
148+ ("start_module" , {"module_kwargs " : {} , "module_name " : "argmax" , "num " : 0 }),
197149 (
198150 "log_metric" ,
199151 {
0 commit comments