-
Notifications
You must be signed in to change notification settings - Fork 415
refactor(gry): refactor reward model #636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ruoyuGao
wants to merge
63
commits into
opendilab:main
Choose a base branch
from
ruoyuGao:ruoyugao
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
63 commits
Select commit
Hold shift + click to select a range
c372c07
refactor network and red reward model
ruoyuGao 6718e4a
create reward model utils
ruoyuGao be7039a
polish network and reward model utils, provide test for them
ruoyuGao a4de466
refactor network for two method: learn and forward
ruoyuGao d615c14
Merge branch 'main' into ruoyugao
ruoyuGao 7a8ec6e
refactor rnd
ruoyuGao 55c7be8
refactor gail
ruoyuGao ff60716
fix gail for unit test
ruoyuGao 6b80392
refactor icm
ruoyuGao 25d49b5
fix wrong unit test in test_reward_model_utils
ruoyuGao c081ff0
refactor gcl and pwil
ruoyuGao f1218cd
refactor pdeil
ruoyuGao d9060c2
add hidden_size_list to gail
ruoyuGao 179182a
change gail test for new config
ruoyuGao d067731
refactor trex network
ruoyuGao 29f0d55
fix style and wrong import
ruoyuGao 4ec0bd3
fix style for trex
ruoyuGao 800f090
Merge branch 'main' into ruoyugao
ruoyuGao c64b5c7
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao 660af32
fix unit test for trex onppo
ruoyuGao 1b0d579
Merge branch 'main' into ruoyugao
ruoyuGao b4e81dd
refactor ngu and provide cartpole config file
ruoyuGao eddc80d
change reward entry
ruoyuGao 6e2b867
change trex entry to new entry, combine old trex test to new test
ruoyuGao e25d265
Merge branch 'main' into ruoyugao
ruoyuGao 97634dc
refactor trex config file
ruoyuGao f099cac
refactor trex config file
ruoyuGao 0c48c08
refactor trex config file
ruoyuGao 594d619
add gail to new reward entry
ruoyuGao 58a2bff
remove preferenced based irl entry(used for trex, drex before)
ruoyuGao e9db652
Merge branch 'main' into ruoyugao
ruoyuGao 822d7a4
remove unuse code in gcl
ruoyuGao be03aa9
change clear data from pipeline to RM && add ngu to new entry
ruoyuGao d3ce3e2
remove ngu old entry
ruoyuGao 4c19aa3
fix env pool test bug
ruoyuGao 0cc2149
add drex to new entry
ruoyuGao 5b4e4cc
fix unit test for trex and gail
ruoyuGao ff4de47
fix style
ruoyuGao 9e63ef1
fix style for drex unittest
ruoyuGao ca2e2db
fix drex unittest
ruoyuGao 8716afe
fix bug in minigrid env
ruoyuGao 9036141
add explain for rm utils
ruoyuGao 6b9754a
move RM unittest into one file
ruoyuGao a52a1c0
Merge branch 'main' into ruoyugao
ruoyuGao a5c7989
add drex config
ruoyuGao d631237
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao f42d131
fix ngu wrapper bug in minigrid
ruoyuGao edff260
fix ngu wrapper bug in minigrid
ruoyuGao 6ab66e1
Merge branch 'main' into ruoyugao
ruoyuGao cb0c627
refactor gcl, add it to reward entry
ruoyuGao 016fbb3
refactor gcl config and bash format other config
ruoyuGao cf50148
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao 919c01b
fix bug for test, remove wrong comment
ruoyuGao a1d0b3a
polish code for ngu, drex, base rm and entry
ruoyuGao 0a0af3c
Merge branch 'main' into ruoyugao
ruoyuGao e310b4c
polish code for all rm
ruoyuGao 92dc227
fix style for ngu
ruoyuGao a4f364d
polish comment for config files
ruoyuGao 1f06dec
add gcl unit test
ruoyuGao a547b3b
polish RM
ruoyuGao 97da5c6
fix style for rnd and icm
ruoyuGao 774b2a4
fix style for rnd and icm
ruoyuGao b78e36c
fix style for icm
ruoyuGao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import Union, Tuple, List, Dict | ||
from easydict import EasyDict | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from ding.utils import SequenceType, REWARD_MODEL_REGISTRY | ||
from ding.model import FCEncoder, ConvEncoder | ||
from ding.torch_utils.data_helper import to_tensor | ||
import numpy as np | ||
|
||
|
||
class FeatureNetwork(nn.Module): | ||
|
||
def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None: | ||
ruoyuGao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
super(FeatureNetwork, self).__init__() | ||
if isinstance(obs_shape, int) or len(obs_shape) == 1: | ||
self.feature = FCEncoder(obs_shape, hidden_size_list) | ||
elif len(obs_shape) == 3: | ||
self.feature = ConvEncoder(obs_shape, hidden_size_list) | ||
else: | ||
raise KeyError( | ||
"not support obs_shape for pre-defined encoder: {}, please customize your own RND model". | ||
format(obs_shape) | ||
) | ||
|
||
def forward(self, obs: torch.Tensor) -> torch.Tensor: | ||
feature_output = self.feature(obs) | ||
return feature_output | ||
|
||
|
||
class RndNetwork(nn.Module): | ||
|
||
def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None: | ||
super(RndNetwork, self).__init__() | ||
self.target = FeatureNetwork(obs_shape, hidden_size_list) | ||
self.predictor = FeatureNetwork(obs_shape, hidden_size_list) | ||
|
||
for param in self.target.parameters(): | ||
param.requires_grad = False | ||
|
||
def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
predict_feature = self.predictor(obs) | ||
with torch.no_grad(): | ||
target_feature = self.target(obs) | ||
return predict_feature, target_feature | ||
|
||
|
||
class RedNetwork(RndNetwork): | ||
|
||
def __init__(self, obs_shape: int, action_shape: int, hidden_size_list: SequenceType) -> None: | ||
# RED network does not support high dimension obs | ||
super().__init__(obs_shape + action_shape, hidden_size_list) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Union, Optional, List, Any, Tuple | ||
from collections.abc import Iterable | ||
|
||
import torch | ||
import torch.optim as optim | ||
import torch.nn.functional as F | ||
|
||
|
||
def concat_state_action_pairs( | ||
data: list, action_size: Optional[int] = None, one_hot: Optional[bool] = False | ||
) -> torch.Tensor: | ||
""" | ||
Overview: | ||
Concatenate state and action pairs from input. | ||
Arguments: | ||
- data (:obj:`List`): List with at least ``obs`` and ``action`` keys. | ||
Returns: | ||
- state_actions_tensor (:obj:`Torch.tensor`): State and action pairs. | ||
""" | ||
states_data = [] | ||
actions_data = [] | ||
#check data(dict) has key obs and action | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 空格 使用 bash format.sh ding 格式化代码 |
||
assert isinstance(data, Iterable) | ||
assert "obs" in data[0] and "action" in data[0] | ||
ruoyuGao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
for item in data: | ||
states_data.append(item['obs'].flatten()) # to allow 3d obs and actions concatenation | ||
if one_hot and action_size: | ||
action = torch.Tensor([int(i == item['action']) for i in range(action_size)]) | ||
ruoyuGao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
actions_data.append(action) | ||
else: | ||
actions_data.append(item['action']) | ||
|
||
states_tensor: torch.Tensor = torch.stack(states_data).float() | ||
actions_tensor: torch.Tensor = torch.stack(actions_data).float() | ||
states_actions_tensor: torch.Tensor = torch.cat([states_tensor, actions_tensor], dim=1) | ||
|
||
return states_actions_tensor |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.