Implement Dominated Novelty Search #664
Hi @ryanboldi, thanks for sending in this PR! Is it ready for review yet?
Hi @btjanaka! Yep it is ready for review. I am still working on the tutorial notebook but I believe they are independent for the sake of code review. Thanks!
Awesome I'll try to take a look next week!
btjanaka
left a comment
Hi @ryanboldi, thank you for your patience! I've gone through the PR and left some comments. I hope you find them helpful! I've mainly pointed out little things in the implementation and tests. I am happy to accept the PR with these fixes and also with a few more tests. If you have time after that, it would be really helpful if we could get an example of DNS running so that we basically "battle test" this implementation and determine if there are any bugs or performance bottlenecks. Some ideas include an addition to the examples/sphere.py script where we currently dump many of our algorithms, or the tutorial that you mentioned. Thanks again for working on this!
Thanks for the detailed feedback! I tried to make the changes you requested. Would appreciate a second review when you get a chance.
## Description

Added dtypes for np.asarray in methods like retrieve and index_of across all the archives. Credit to #664 for pointing this out in DNSArchive's retrieve method.

## Status

- [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have linted and formatted my code with `ruff` and `ty`
- [x] I have tested my code by running `pytest`
- [x] I have added a description of my change to the changelog in `HISTORY.md`
- [x] This PR is ready to go
btjanaka
left a comment
Hi @ryanboldi, thanks for making these changes! I had a couple questions regarding your PR; could you take a look at those? Also, please make sure to pull before making any new changes, as I made some edits to your code. I think we are pretty close to merging!
def compute_dns(self, measures: ArrayLike, objectives: ArrayLike) -> np.ndarray:
    """Computes DNS scores for a batch against the current population.
def _compute_dns(self, measures: ArrayLike, objectives: ArrayLike) -> np.ndarray:
I notice you switched this to a private method. Are there cases where people would want to compute the DNS score without adding to the archive? For example, maybe someone wants to make a heatmap showing the DNS score across the archive? In such a case, a public API would be useful.
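For reference, here is a minimal sketch of what a public scoring helper could compute for such a heatmap. It assumes the usual DNS definition (mean distance in measure space to the k nearest solutions with a strictly higher objective, and infinity when nothing is fitter). The function name, the `k` parameter, and the brute-force pairwise distances are illustrative assumptions, not the code in this PR.

```python
import numpy as np


def dns_scores_sketch(measures, objectives, k=3):
    """Hypothetical standalone helper; not the PR's implementation."""
    measures = np.asarray(measures, dtype=np.float64)
    objectives = np.asarray(objectives, dtype=np.float64)
    n = len(objectives)

    # Pairwise Euclidean distances in measure space, shape (n, n).
    diffs = measures[:, None, :] - measures[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)

    # fitter[i, j] is True when solution j has a strictly higher
    # objective than solution i.
    fitter = objectives[None, :] > objectives[:, None]

    # Solutions with no fitter neighbor keep a score of infinity.
    scores = np.full(n, np.inf)
    for i in range(n):
        d = dists[i, fitter[i]]
        if d.size > 0:
            # Mean distance to the (at most) k nearest fitter solutions.
            scores[i] = np.sort(d)[:k].mean()
    return scores
```

Because the helper only reads its inputs, it could be called on the archive's current measures and objectives to color a heatmap without adding anything to the archive.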
check_shape(measures, "measures", self.measure_dim, "measure_dim")
check_finite(measures, "measures")
return self.index_of(measures[None])[0]
return int(self.index_of(measures[None])[0])
Any reason why this has to be cast to int? Since you are calling index_of, which casts to int32, this should also end up being int32.
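To make the difference concrete, here is a small sketch of what the cast changes. The array below is only a stand-in for the result of index_of, assuming it returns int32 indices as described above.

```python
import numpy as np

# Stand-in for the batch returned by index_of (assumed to be int32).
indices = np.asarray([7], dtype=np.int32)

# Indexing a NumPy array yields a NumPy scalar that keeps the dtype.
elem = indices[0]

# Wrapping in int() converts it to a built-in Python int instead.
as_python_int = int(indices[0])

print(type(elem))
print(type(as_python_int))
```

So the cast only matters if callers need a built-in int rather than a NumPy int32 scalar.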
for name in self._store.field_list:
    combined[name] = (
        np.concatenate((cur[name], data[name]), axis=0)
        if name in data
Maybe we should throw an error instead when `name` is not in `data`? I don't imagine your code will behave very nicely if you continue to use `combined` with some entries being longer than others.
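A rough sketch of the fail-fast variant, assuming `cur` and `data` are plain dicts of arrays keyed by field name. The helper name `combine_fields` and the choice of `ValueError` are hypothetical, not part of the PR.

```python
import numpy as np


def combine_fields(field_list, cur, data):
    """Hypothetical helper: concatenates per-field arrays, failing on gaps."""
    # Reject the batch up front so no field ends up shorter than the others.
    missing = [name for name in field_list if name not in data]
    if missing:
        raise ValueError(f"data is missing fields: {missing}")

    return {
        name: np.concatenate((cur[name], data[name]), axis=0)
        for name in field_list
    }
```

With this shape, every entry in the result has the same leading dimension or the call fails before anything is combined.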
if n_total <= cap:
    survivor_indices = np.arange(n_total)
else:
    # Take largest `cap` values.
I think this can be done with:
survivor_indices = np.argsort(dns_scores)[-cap:]
Does that achieve what you intend?
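A quick check of that one-liner on made-up scores, including an infinite one. The `np.argpartition` variant is shown only as a possible alternative that avoids the full sort; the sample values and `cap` are illustrative.

```python
import numpy as np

# Made-up scores; np.inf sorts after every finite value.
dns_scores = np.array([0.5, np.inf, 0.1, 2.0, 1.0])
cap = 2

# Full sort, then keep the indices of the `cap` largest scores
# (returned in ascending score order).
by_argsort = np.argsort(dns_scores)[-cap:]

# Partial-sort alternative: the same set of indices, but with no
# guaranteed order among them.
by_argpartition = np.argpartition(dns_scores, -cap)[-cap:]

print(by_argsort)
print(sorted(by_argpartition.tolist()))
```

Both select indices 1 and 3 here, so either should work unless the order of the survivors matters.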
@@ -456,11 +378,23 @@

# Update stats.
if len(self) > 0:
Is there any case where adding will result in the archive shrinking in size? If so, this if statement would cause the archive stats to no longer be current.
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
check_finite(measures, "measures")

occupied, data = cast(
Any reason why this has to be cast?
if self.empty:
    raise IndexError("No elements in archive.")

# Deterministic selection: return the first n elites (in storage order).
Why does this have to be deterministic selection?
) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
    return self._store.data(fields, return_type)

def sample_elites(self, n: Int) -> BatchData:
Note to self: Add replace parameter?
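One possible shape for random sampling with a `replace` option, sketched as a free function over elite indices. The name `sample_indices`, the `rng` argument, and the error for oversized requests are assumptions for illustration, not an existing pyribs API.

```python
import numpy as np


def sample_indices(n_elites, n, replace=True, rng=None):
    """Hypothetical helper: picks `n` elite indices uniformly at random."""
    # Accepts None, an integer seed, or an existing Generator.
    rng = np.random.default_rng(rng)

    if not replace and n > n_elites:
        raise ValueError(
            "Cannot sample more elites than exist without replacement."
        )

    # Uniform choice over the index range [0, n_elites).
    return rng.choice(n_elites, size=n, replace=replace)
```

The returned indices could then be used to gather the matching rows from each stored field.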
Description
Add Dominated Novelty Search (https://arxiv.org/abs/2502.00593) as an archive to be used in pyribs.
TODO
Status
- CONTRIBUTING.md
- `ruff` and `ty`
- `pytest`
- HISTORY.md