Skip to content

Commit ca25b27

Browse files
authored
Merge branch 'main' into p2p-rpc
2 parents adab7fb + 9d183d9 commit ca25b27

File tree

11 files changed

+13
-8
lines changed

11 files changed

+13
-8
lines changed

ding/framework/middleware/tests/test_trainer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22
import random
3-
import torch
43
import copy
4+
import torch
5+
import treetensor.torch as ttorch
56
from unittest.mock import Mock, patch
67
from ding.data.buffer import DequeBuffer
78
from ding.framework import OnlineRLContext, task
@@ -10,6 +11,8 @@
1011

1112

1213
class MockPolicy(Mock):
14+
_device = 'cpu'
15+
1316
# MockPolicy class for train mode
1417
def forward(self, train_data, **kwargs):
1518
res = {
@@ -19,6 +22,8 @@ def forward(self, train_data, **kwargs):
1922

2023

2124
class MultiStepMockPolicy(Mock):
25+
_device = 'cpu'
26+
2227
# MockPolicy class for multi-step train mode
2328
def forward(self, train_data, **kwargs):
2429
res = [
@@ -34,7 +39,7 @@ def forward(self, train_data, **kwargs):
3439

3540
def get_mock_train_input():
3641
data = {'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2), 'reward': random.random(), 'info': {}}
37-
return data
42+
return ttorch.as_tensor(data)
3843

3944

4045
@pytest.mark.unittest

ding/torch_utils/loss/contrastive_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
7979
x_n = x.view(-1, self._encode_shape)
8080
y_n = y.view(-1, self._encode_shape)
8181

82-
# Use inner product to obtain postive samples.
82+
# Use inner product to obtain positive samples.
8383
# [N, x_heads, encode_dim] * [N, encode_dim, y_heads] -> [N, x_heads, y_heads]
8484
u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)
8585
# Use outer product to obtain all sample permutations.
@@ -92,7 +92,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
9292
u_neg = (n_mask * u_all) - (10. * (1 - n_mask))
9393
u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)
9494

95-
# Concatenate postive and negative samples and apply log softmax.
95+
# Concatenate positive and negative samples and apply log softmax.
9696
pred_lgt = torch.cat([u_pos, u_neg], dim=2)
9797
pred_log = F.log_softmax(pred_lgt * self._temperature, dim=2)
9898

dizoo/beergame/__init__.py

Whitespace-only changes.

dizoo/procgen/__init__.py

Whitespace-only changes.

dizoo/rocket/__init__.py

Whitespace-only changes.

dizoo/rocket/config/__init__.py

Whitespace-only changes.

dizoo/rocket/entry/__init__.py

Whitespace-only changes.

dizoo/sokoban/__init__.py

Whitespace-only changes.

docker/Dockerfile.base

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime as base
33
WORKDIR /ding
44

55
RUN apt update \
6-
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils -y \
6+
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip -y \
77
&& apt clean \
88
&& rm -rf /var/cache/apt/* \
99
&& sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
@@ -33,7 +33,7 @@ RUN apt-get update && \
3333
python3.8 python3-pip python3.8-dev
3434

3535
RUN apt update \
36-
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils -y \
36+
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip -y \
3737
&& apt clean \
3838
&& rm -rf /var/cache/apt/* \
3939
&& sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \

docker/Dockerfile.env

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,8 @@ WORKDIR /ding
9191
RUN mkdir tempfile \
9292
&& cd tempfile \
9393
&& wget https://github.com/rlworkgroup/metaworld/archive/refs/heads/master.zip -O metaworld_master.zip \
94-
&& apt-get install unzip \
9594
&& unzip metaworld_master.zip \
96-
&& python3 -m pip install --no-cache-dir ./metaworld-master/ \
95+
&& python3 -m pip install --no-cache-dir ./Metaworld-master/ \
9796
&& cd .. \
9897
&& rm -rf tempfile
9998

0 commit comments

Comments
 (0)