1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ import math
5
+ import random
4
6
from abc import ABC , abstractmethod
5
7
from collections import defaultdict
6
8
from typing import TYPE_CHECKING
7
9
10
+ from rdagent .log import rdagent_logger as logger
11
+
8
12
if TYPE_CHECKING :
9
13
from rdagent .scenarios .data_science .proposal .exp_gen .base import DSTrace
10
14
@@ -22,7 +26,7 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
22
26
23
27
For proposing selections, we have to follow the rules
24
28
- Suggest selection: suggest a selection that is suitable for the current trace.
25
- - Suggested should be garenteed to be recorded at last!!!
29
+ - Suggested should be guaranteed to be recorded at last!!!!
26
30
- If no suitable selection is found, the function should async wait!!!!
27
31
28
32
Args:
@@ -35,17 +39,8 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
35
39
raise NotImplementedError
36
40
37
41
38
- class RoundRobinScheduler (TraceScheduler ):
39
- """
40
- A concurrency-safe scheduling strategy that cycles through active traces
41
- in a round-robin fashion.
42
-
43
- NOTE: we don't need to use asyncio.Lock here as the kickoff_loop ensures the ExpGen is always sequential, instead of parallel.
44
- """
45
-
46
- def __init__ (self , max_trace_num : int ):
47
- self .max_trace_num = max_trace_num
48
- self ._last_selected_leaf_id = - 1
42
class BaseScheduler(TraceScheduler):
    """Template scheduler: tracks selections that have been handed out but not
    yet recorded in the trace ("uncommited"), and delegates the actual pick to
    the subclass hook ``select()``.
    """

    def __init__(self):
        # Index of the first trace record that has NOT yet been committed.
        self.rec_commit_idx = 0  # the node before rec_idx is already committed.
        # Per-parent count of selections handed out but not yet recorded.
        self.uncommited_rec_status = defaultdict(int)  # the uncommited record status

    async def next(self, trace: DSTrace) -> tuple[int, ...]:
        """Return the parent selection for the next experiment.

        Loops forever: (1) commit every newly recorded selection by decrementing
        its pending counter, (2) ask ``select()`` for a new selection and, if one
        is produced, reserve it by incrementing the counters and return it,
        (3) otherwise async-wait briefly and retry.
        """
        while True:
            # step 0: Commit the pending selections
            for i in range(self.rec_commit_idx, len(trace.dag_parent)):
                parent_of_i = trace.dag_parent[i]
                if parent_of_i == trace.NEW_ROOT:
                    self.uncommited_rec_status[trace.NEW_ROOT] -= 1
                else:
                    for p in parent_of_i:
                        self.uncommited_rec_status[p] -= 1
            # NOTE(review): the loop above scans trace.dag_parent but the commit
            # index is advanced to len(trace.hist) — assumes both lists stay the
            # same length; confirm against DSTrace.
            self.rec_commit_idx = len(trace.hist)

            # step 1: ask the concrete scheduler for a selection.
            parents = self.select(trace)

            if parents is not None:
                # Reserve the selection so it is not handed out again before the
                # corresponding record lands in the trace.
                if parents == trace.NEW_ROOT:
                    self.uncommited_rec_status[trace.NEW_ROOT] += 1
                else:
                    for p in parents:
                        self.uncommited_rec_status[p] += 1
                return parents

            # No suitable selection right now -- wait and poll again.
            await asyncio.sleep(1)

    @abstractmethod
    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """Selects the parent nodes for the new experiment, or None if no selection can be made."""
        raise NotImplementedError
80
class RoundRobinScheduler(BaseScheduler):
    """
    A concurrency-safe scheduling strategy that cycles through active traces
    in a round-robin fashion.

    NOTE: we don't need to use asyncio.Lock here as the kickoff_loop ensures the ExpGen is always sequential, instead of parallel.
    """

    def __init__(self, max_trace_num: int, *args, **kwargs):
        logger.info(f"RoundRobinScheduler: max_trace_num={max_trace_num}")
        self.max_trace_num = max_trace_num
        # Leaf id chosen on the previous call; scanning resumes after it so
        # every leaf gets a fair turn.
        self._last_selected_leaf_id = -1
        super().__init__()

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """
        Atomically selects the next leaf node from the trace in order.
        If no suitable selection is found, return None.

        Fix: the original implementation never read or updated
        ``_last_selected_leaf_id`` and always returned the first idle leaf,
        which is first-available, not round-robin.
        """
        # Policy: if we have fewer traces than our target, start a new one.
        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
            return trace.NEW_ROOT

        # Step2: suggest a selection to a not expanding leave, resuming the
        # scan just after the previously selected leaf (wrap-around).
        leaves = list(trace.get_leaves())
        if not leaves:
            return None

        # assumes get_leaves() yields leaf ids in ascending order — TODO confirm
        start = 0
        for idx, leaf in enumerate(leaves):
            if leaf > self._last_selected_leaf_id:
                start = idx
                break

        for leaf in leaves[start:] + leaves[:start]:
            if self.uncommited_rec_status[leaf] == 0:
                self._last_selected_leaf_id = leaf
                return (leaf,)

        return None
+
112
+ # ======================================================================================
113
+ # Probabilistic Scheduler and its potential functions
114
+ # ======================================================================================
115
+
116
+
117
class ProbabilisticScheduler(BaseScheduler):
    """
    A concurrency-safe scheduling strategy that samples the next trace to expand
    based on a probability distribution derived from a potential function.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation. Higher values make selection more uniform.
        """
        # Validate configuration up front so misconfiguration fails loudly.
        if max_trace_num <= 0:
            raise ValueError("max_trace_num must be positive.")
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        self.max_trace_num = max_trace_num
        self.temperature = temperature
        super().__init__()

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Potential score for one leaf; higher means more likely to be picked.
        The base implementation scores every leaf equally (uniform sampling).
        Subclasses override this to encode their own preference.
        """
        return 1.0

    def _softmax_probabilities(self, potentials: list[float]) -> list[float]:
        """
        Turn raw potential scores into a probability distribution via a
        temperature-scaled, numerically stable softmax.
        """
        if not potentials:
            return []

        # Temperature scaling, then subtract the max before exponentiating to
        # avoid overflow (standard softmax stabilization).
        scaled = [score / self.temperature for score in potentials]
        peak = max(scaled)
        weights = [math.exp(score - peak) for score in scaled]
        total = sum(weights)

        if total == 0:
            # Degenerate case: fall back to a uniform distribution.
            uniform = 1.0 / len(potentials)
            return [uniform] * len(potentials)

        return [weight / total for weight in weights]

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """
        Selects the next leaf node based on probabilistic sampling.
        Returns None when every leaf already has a pending selection.
        """
        # Spawning new traces takes priority until the target count is reached.
        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
            return trace.NEW_ROOT

        # Only leaves with no uncommitted selection are eligible.
        candidates = [leaf for leaf in trace.get_leaves() if self.uncommited_rec_status[leaf] == 0]
        if not candidates:
            return None

        # Score each candidate, reject invalid (negative) potentials.
        scores = [self.calculate_potential(trace, leaf) for leaf in candidates]
        if any(score < 0 for score in scores):
            raise ValueError("Potential function returned a negative value.")

        # Sample one leaf proportionally to the softmax of the scores.
        distribution = self._softmax_probabilities(scores)
        chosen = random.choices(candidates, weights=distribution, k=1)[0]
        return (chosen,)
+
210
class TraceLengthScheduler(ProbabilisticScheduler):
    """
    A scheduler that prefers longer traces (more experiments)
    -- default: prefer to expand the trace that has more experiments (quicker to get the result).
    -- if inverse=True, prefer to expand the trace that has less experiments.

    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, inverse: bool = False, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation.
            inverse: If True, shorter traces get higher potential.
        """
        logger.info(
            f"TraceLengthScheduler: max_trace_num={max_trace_num}, temperature={temperature}, inverse={inverse}"
        )
        super().__init__(max_trace_num, temperature)
        self.inverse = inverse

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Potential derived from the number of ancestors on the path to *leaf_id*:
        the path length itself by default, its reciprocal when inverse=True.
        """
        depth = len(trace.get_parents(leaf_id))

        # A root-only leaf has no ancestors; give it a neutral score.
        if depth == 0:
            return 1.0

        if self.inverse:
            return 1.0 / depth
        return float(depth)
+
245
class SOTABasedScheduler(ProbabilisticScheduler):
    """
    A scheduler that prefers traces with more SOTA (State of the Art) results.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, inverse: bool = False, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation.
            inverse: If True, fewer SOTA results get higher potential.
        """
        logger.info(f"SOTABasedScheduler: max_trace_num={max_trace_num}, temperature={temperature}, inverse={inverse}")
        super().__init__(max_trace_num, temperature)
        self.inverse = inverse

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Calculate potential based on the number of SOTA results in the trace.

        Counts the ancestors of *leaf_id* whose recorded feedback has
        decision=True. (Idiom fix: the original unpacked ``exp, feedback``
        from each history entry, binding an unused ``exp`` local.)
        """
        path = trace.get_parents(leaf_id)
        # hist entries are (experiment, feedback) pairs; only feedback.decision
        # matters here. Guard against ids beyond the recorded history.
        sota_count = sum(
            1 for node_id in path if node_id < len(trace.hist) and trace.hist[node_id][1].decision
        )

        if self.inverse:
            # Add 1 to avoid division by zero and give traces with 0 SOTAs the highest potential.
            return 1.0 / (sota_count + 1)
        return float(sota_count)
+
282
class RandomScheduler(ProbabilisticScheduler):
    """
    A scheduler that selects traces randomly with uniform distribution.
    """

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Draw a fresh random score in [0, 1) for every leaf on every call, so
        the softmax weights — and hence the selection — are randomized.
        """
        return random.random()
0 commit comments