Skip to content

Commit 2594172

Browse files
committed
format change
Signed-off-by: gdeng <gdeng@nvidia.com>
1 parent 489fd48 commit 2594172

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

emerging_optimizers/triton_kernels/syrk.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import triton
1818
import triton.language as tl
1919

20+
2021
try:
2122
from triton.tools.tensor_descriptor import TensorDescriptor
2223
except ImportError:
@@ -322,9 +323,9 @@ def tsyrk_ex(
322323
assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous."
323324

324325
N, K = a.shape
325-
assert (c is None and beta == 0.0) or (
326-
c is not None and c.shape == (N, N)
327-
), "if c is provided, c must be of shape (N, N)"
326+
assert (c is None and beta == 0.0) or (c is not None and c.shape == (N, N)), (
327+
"if c is provided, c must be of shape (N, N)"
328+
)
328329
assert c is None or c.is_contiguous() or c.T.is_contiguous(), "if c is provided, c or c.T must be contiguous"
329330

330331
d = torch.empty((N, N), device=a.device, dtype=a.dtype)

tests/test_triton_kernels.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from absl.testing import absltest, parameterized
1615
import torch
16+
from absl.testing import absltest, parameterized
1717

18-
from emerging_optimizers import triton_kernels
19-
from emerging_optimizers import utils
18+
from emerging_optimizers import triton_kernels, utils
2019

2120

2221
class TritonKernelsTest(parameterized.TestCase):
23-
2422
@parameterized.product(
2523
({"n": 5, "k": 7, "rtol": 1e-5}, {"n": 17, "k": 23, "rtol": 1e-3}, {"n": 127, "k": 255, "rtol": 2e-2}),
2624
({"trans": False},),
@@ -52,7 +50,6 @@ def test_ssyrk_tf32_not_far_from_matmul(self, n: int, k: int, trans: bool, rtol:
5250

5351

5452
class TritonKernelsIntegerInputTest(parameterized.TestCase):
55-
5653
@parameterized.product(
5754
({"n": 5, "k": 7}, {"n": 17, "k": 23}, {"n": 127, "k": 255}),
5855
({"trans": False},),

0 commit comments

Comments
 (0)