Skip to content

Commit 081394b

Browse files
authored
Move project description from setup.py to pyproject.toml (#21)
Also applied ruff
1 parent 80354e0 commit 081394b

File tree

5 files changed

+13
-14
lines changed

5 files changed

+13
-14
lines changed

pyproject.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
[tool.poetry]
2-
authors = ["Perplexity AI"]
3-
description = ""
1+
[project]
42
name = "pplx-kernels"
5-
readme = "README.md"
63
version = "0.0.1"
4+
description = "Perplexity CUDA Kernels"
5+
readme = "README.md"
6+
requires-python = ">=3.12"
7+
8+
[build-system]
9+
requires = ["setuptools>=61.0", "wheel", "torch"]
10+
build-backend = "setuptools.build_meta"
711

812
[tool.ruff]
913
line-length = 88

setup.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from setuptools.command.build import build as _build
77
from setuptools.command.build_ext import build_ext
88

9-
VERSION = "0.0.1"
10-
119

1210
def _get_torch_cmake_prefix_path() -> str:
1311
import torch
@@ -88,9 +86,6 @@ def run(self) -> None:
8886
]
8987

9088
setup(
91-
name="pplx-kernels",
92-
version=VERSION,
93-
description="Perplexity Kernels",
9489
packages=find_packages(where="src"),
9590
package_dir={"": "src"},
9691
package_data={
@@ -103,8 +98,6 @@ def run(self) -> None:
10398
},
10499
options={"bdist_wheel": {"py_limited_api": "cp39"}},
105100
zip_safe=False,
106-
install_requires=["torch"],
107-
python_requires=">=3.10",
108101
ext_modules=extensions,
109102
include_package_data=True,
110103
)

src/pplx_kernels/all_to_all.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pyright: reportCallIssue=false
22

3-
from typing import Any, Callable
3+
from collections.abc import Callable
4+
from typing import Any
45

56
import torch
67

tests/distributed_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import dataclasses
22
import logging
33
import os
4-
from typing import Callable, Concatenate, ParamSpec
4+
from collections.abc import Callable
5+
from typing import Concatenate, ParamSpec
56

67
import pytest
78
import torch

tests/test_all_to_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _do_test_all_to_all(
130130
logger.debug(
131131
" x[%d] -> %s",
132132
token_idx,
133-
list(zip(indices, weights)),
133+
list(zip(indices, weights, strict=False)),
134134
)
135135
for token_idx in range(rd.num_tokens):
136136
logger.debug(" x[%d]=%s", token_idx, _str_1d_tensor(rd.x[token_idx]))

0 commit comments

Comments
 (0)