Skip to content

Commit a611de0

Browse files
authored
Removing numpy requirement from all files in examples/pytorch/domain_templates (#19947)
1 parent 812ffde commit a611de0

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

examples/pytorch/domain_templates/generative_adversarial_net.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
2020
"""
2121

22+
import math
2223
from argparse import ArgumentParser, Namespace
2324

24-
import numpy as np
2525
import torch
2626
import torch.nn as nn
2727
import torch.nn.functional as F
@@ -59,7 +59,7 @@ def block(in_feat, out_feat, normalize=True):
5959
*block(128, 256),
6060
*block(256, 512),
6161
*block(512, 1024),
62-
nn.Linear(1024, int(np.prod(img_shape))),
62+
nn.Linear(1024, int(math.prod(img_shape))),
6363
nn.Tanh(),
6464
)
6565

@@ -80,7 +80,7 @@ def __init__(self, img_shape):
8080
super().__init__()
8181

8282
self.model = nn.Sequential(
83-
nn.Linear(int(np.prod(img_shape)), 512),
83+
nn.Linear(int(math.prod(img_shape)), 512),
8484
nn.LeakyReLU(0.2, inplace=True),
8585
nn.Linear(512, 256),
8686
nn.LeakyReLU(0.2, inplace=True),

examples/pytorch/domain_templates/reinforce_learn_Qnet.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
"""
3434

3535
import argparse
36+
import random
3637
from collections import OrderedDict, deque, namedtuple
3738
from typing import Iterator, List, Tuple
3839

3940
import gym
40-
import numpy as np
4141
import torch
4242
import torch.nn as nn
4343
import torch.optim as optim
@@ -103,15 +103,15 @@ def append(self, experience: Experience) -> None:
103103
self.buffer.append(experience)
104104

105105
def sample(self, batch_size: int) -> Tuple:
106-
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
106+
indices = random.sample(range(len(self.buffer)), batch_size)
107107
states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))
108108

109109
return (
110-
np.array(states),
111-
np.array(actions),
112-
np.array(rewards, dtype=np.float32),
113-
np.array(dones, dtype=np.bool),
114-
np.array(next_states),
110+
torch.tensor(states),
111+
torch.tensor(actions),
112+
torch.tensor(rewards, dtype=torch.float32),
113+
torch.tensor(dones, dtype=torch.bool),
114+
torch.tensor(next_states),
115115
)
116116

117117

@@ -175,7 +175,7 @@ def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
175175
action
176176
177177
"""
178-
if np.random.random() < epsilon:
178+
if random.random() < epsilon:
179179
action = self.env.action_space.sample()
180180
else:
181181
state = torch.tensor([self.state])

examples/pytorch/domain_templates/semantic_segmentation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import random
1717
from argparse import ArgumentParser, Namespace
1818

19-
import numpy as np
2019
import torch
2120
import torch.nn.functional as F
2221
import torchvision.transforms as transforms
@@ -107,11 +106,11 @@ def __len__(self):
107106
def __getitem__(self, idx):
108107
img = Image.open(self.img_list[idx])
109108
img = img.resize(self.img_size)
110-
img = np.array(img)
109+
img = torch.tensor(img)
111110

112111
mask = Image.open(self.mask_list[idx]).convert("L")
113112
mask = mask.resize(self.img_size)
114-
mask = np.array(mask)
113+
mask = torch.tensor(mask)
115114
mask = self.encode_segmap(mask)
116115

117116
if self.transform:

0 commit comments

Comments
 (0)