Skip to content

Commit 82cf0a1

Browse files
committed
fix(nyz): fix multi trainer test
1 parent f6808a9 commit 82cf0a1

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

ding/framework/middleware/tests/test_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22
import random
3-
import torch
43
import copy
4+
import torch
5+
import treetensor.torch as ttorch
56
from unittest.mock import Mock, patch
67
from ding.data.buffer import DequeBuffer
78
from ding.framework import OnlineRLContext, task
@@ -10,6 +11,7 @@
1011

1112

1213
class MockPolicy(Mock):
14+
_device = 'cpu'
1315
# MockPolicy class for train mode
1416
def forward(self, train_data, **kwargs):
1517
res = {
@@ -19,6 +21,7 @@ def forward(self, train_data, **kwargs):
1921

2022

2123
class MultiStepMockPolicy(Mock):
24+
_device = 'cpu'
2225
# MockPolicy class for multi-step train mode
2326
def forward(self, train_data, **kwargs):
2427
res = [
@@ -34,7 +37,7 @@ def forward(self, train_data, **kwargs):
3437

3538
def get_mock_train_input():
3639
data = {'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2), 'reward': random.random(), 'info': {}}
37-
return data
40+
return ttorch.as_tensor(data)
3841

3942

4043
@pytest.mark.unittest

docker/Dockerfile.base

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime as base
33
WORKDIR /ding
44

55
RUN apt update \
6-
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip -y \
6+
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip -y \
77
&& apt clean \
88
&& rm -rf /var/cache/apt/* \
99
&& sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
@@ -33,7 +33,7 @@ RUN apt-get update && \
3333
python3.8 python3-pip python3.8-dev
3434

3535
RUN apt update \
36-
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip -y \
36+
&& apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip -y \
3737
&& apt clean \
3838
&& rm -rf /var/cache/apt/* \
3939
&& sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \

0 commit comments

Comments
 (0)