Skip to content

Commit 889b460

Browse files
committed
clean ups
1 parent cf686f1 commit 889b460

File tree

3 files changed

+0
-10
lines changed

3 files changed

+0
-10
lines changed

pufferlib/extensions/puffernet.pyx

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def puf_relu(cnp.ndarray input, cnp.ndarray output, int size):
5151
def puf_gelu(cnp.ndarray input, cnp.ndarray output, int size):
5252
_gelu(<float*> input.data, <float*> output.data, size)
5353

54-
def puf_gelu(cnp.ndarray input, cnp.ndarray output, int size):
55-
_gelu(<float*> input.data, <float*> output.data, size)
56-
5754
def puf_sigmoid(float x):
5855
return _sigmoid(x)
5956

tests/test_puffernet.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ def test_puffernet_convolution_3d_layer(batch_size=4096, in_width=9, in_height=5
126126
torch_conv.bias.data = bias_torch
127127
output_torch = torch_conv(input_torch).detach()
128128
assert_near(output_puffer, output_torch.numpy())
129-
130-
131-
132129

133130
def test_puffernet_lstm(batch_size=16, input_size=128, hidden_size=128):
134131
input_np = make_dummy_data(batch_size, input_size, seed=42)
@@ -260,7 +257,6 @@ def test_nmmo3(batch_size=1, input_size=512, hidden_size=512):
260257
if __name__ == '__main__':
261258
test_puffernet_relu()
262259
test_puffernet_gelu()
263-
test_puffernet_gelu()
264260
test_puffernet_sigmoid()
265261
test_puffernet_linear_layer()
266262
test_puffernet_convolution_layer()

tests/test_puffernet_linearlstm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ def assert_near(a, b, tolerance=1e-4):
2525
assert np.all(np.abs(a - b) < tolerance), f"Value mismatch exceeds tolerance {tolerance}"
2626

2727
def test_tetris_puffernet(model_path='puffer_tetris_weights.bin'):
28-
"""
29-
Compare the PyTorch Tetris policy with the PufferNet C++ implementation layer by layer.
30-
"""
3128
# Load the environment to get parameters
3229
env = env_creator('puffer_tetris')()
3330

0 commit comments

Comments
 (0)