@@ -34,19 +34,18 @@ class PriorityFunction(ABC):
3434 Each priority_fn,
3535 Args:
3636 item: List[Experience], assume that all experiences in it have the same model_version and use_count
37- kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
37+ priority_fn_args: Dict, the arguments for priority_fn
38+
3839 Returns:
3940 priority: float
4041 put_into_queue: bool, decide whether to put item into queue
42+
4244 Note that put_into_queue takes effect both for a new item from the explorer and for an item sampled from the buffer.
4244 """
4345
44- def __init__ (self , ** kwargs ):
45- pass
46-
4746 @abstractmethod
48- def __call__ (self , items : List [Experience ]) -> Tuple [float , bool ]:
49- """Calculate the priority of items ."""
47+ def __call__ (self , item : List [Experience ]) -> Tuple [float , bool ]:
48+ """Calculate the priority of item ."""
5049
5150 @classmethod
5251 @abstractmethod
@@ -61,11 +60,11 @@ class LinearDecayPriority(PriorityFunction):
6160 Priority is calculated as `model_version - decay * use_count`. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
6261 """
6362
64- def __init__ (self , decay : float = 2.0 , ** kwargs ):
63+ def __init__ (self , decay : float = 2.0 ):
6564 self .decay = decay
6665
67- def __call__ (self , items : List [Experience ]) -> Tuple [float , bool ]:
68- priority = float (items [0 ].info ["model_version" ] - self .decay * items [0 ].info ["use_count" ])
66+ def __call__ (self , item : List [Experience ]) -> Tuple [float , bool ]:
67+ priority = float (item [0 ].info ["model_version" ] - self .decay * item [0 ].info ["use_count" ])
6968 put_into_queue = True
7069 return priority , put_into_queue
7170
@@ -83,17 +82,17 @@ class LinearDecayUseCountControlPriority(PriorityFunction):
8382 Priority is calculated as `model_version - decay * use_count`; if `sigma` is non-zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if use count does not exceed `use_count_limit`.
8483 """
8584
86- def __init__ (self , decay : float = 2.0 , use_count_limit : int = 3 , sigma : float = 0.0 , ** kwargs ):
85+ def __init__ (self , decay : float = 2.0 , use_count_limit : int = 3 , sigma : float = 0.0 ):
8786 self .decay = decay
8887 self .use_count_limit = use_count_limit
8988 self .sigma = sigma
9089
91- def __call__ (self , items : List [Experience ]) -> Tuple [float , bool ]:
92- priority = float (items [0 ].info ["model_version" ] - self .decay * items [0 ].info ["use_count" ])
90+ def __call__ (self , item : List [Experience ]) -> Tuple [float , bool ]:
91+ priority = float (item [0 ].info ["model_version" ] - self .decay * item [0 ].info ["use_count" ])
9392 if self .sigma > 0.0 :
9493 priority += float (np .random .randn () * self .sigma )
9594 put_into_queue = (
96- items [0 ].info ["use_count" ] < self .use_count_limit if self .use_count_limit > 0 else True
95+ item [0 ].info ["use_count" ] < self .use_count_limit if self .use_count_limit > 0 else True
9796 )
9897 return priority , put_into_queue
9998
@@ -203,7 +202,8 @@ def __init__(
203202 self .item_count = 0
204203 self .priority_groups = SortedDict () # Maps priority -> deque of items
205204 priority_fn_cls = PRIORITY_FUNC .get (priority_fn )
206- kwargs = priority_fn_cls .default_config ().update (priority_fn_args or {})
205+ kwargs = priority_fn_cls .default_config ()
206+ kwargs .update (priority_fn_args or {})
207207 self .priority_fn = priority_fn_cls (** kwargs )
208208 self .reuse_cooldown_time = reuse_cooldown_time
209209 self ._condition = asyncio .Condition () # For thread-safe operations
0 commit comments