Skip to content

Commit 8040918

Browse files
committed
setup.py: add compile flags for bf16 and fp8.
1 parent 9c8c42a commit 8040918

File tree

2 files changed: +26 −13 lines

csrc/permute.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ std::tuple<torch::Tensor, torch::Tensor, std::vector<Tensor>> moe_permute_topK_o
523523

524524
break;
525525
}
526-
// #ifdef ENABLE_BF16
526+
#ifdef ENABLE_BF16
527527
case at::ScalarType::BFloat16:
528528
{
529529
using dType = cutlass::bfloat16_t;
@@ -545,8 +545,8 @@ std::tuple<torch::Tensor, torch::Tensor, std::vector<Tensor>> moe_permute_topK_o
545545

546546
break;
547547
}
548-
// #endif
549-
// #ifdef ENABLE_FP8
548+
#endif
549+
#ifdef ENABLE_FP8
550550
case at::ScalarType::Float8_e5m2:
551551
{
552552
using dType = cutlass::float_e5m2_t;
@@ -589,7 +589,7 @@ std::tuple<torch::Tensor, torch::Tensor, std::vector<Tensor>> moe_permute_topK_o
589589

590590
break;
591591
}
592-
// #endif
592+
#endif
593593
default:
594594
throw std::runtime_error("Wrong activation tensor type.");
595595
}
@@ -670,7 +670,7 @@ torch::Tensor moe_recover_topK_op(
670670

671671
break;
672672
}
673-
// #ifdef ENABLE_BF16
673+
#ifdef ENABLE_BF16
674674
case at::ScalarType::BFloat16:
675675
{
676676
using dType = cutlass::bfloat16_t;
@@ -692,8 +692,8 @@ torch::Tensor moe_recover_topK_op(
692692

693693
break;
694694
}
695-
// #endif
696-
// #ifdef ENABLE_FP8
695+
#endif
696+
#ifdef ENABLE_FP8
697697
case at::ScalarType::Float8_e5m2:
698698
{
699699
using dType = cutlass::float_e5m2_t;
@@ -736,7 +736,7 @@ torch::Tensor moe_recover_topK_op(
736736

737737
break;
738738
}
739-
// #endif
739+
#endif
740740
default:
741741
throw std::runtime_error("Wrong activation tensor type.");
742742
}
@@ -819,7 +819,7 @@ std::tuple<torch::Tensor, torch::Tensor> moe_recover_topK_bwd_op(
819819

820820
break;
821821
}
822-
// #ifdef ENABLE_BF16
822+
#ifdef ENABLE_BF16
823823
case at::ScalarType::BFloat16:
824824
{
825825
using dType = cutlass::bfloat16_t;
@@ -844,8 +844,8 @@ std::tuple<torch::Tensor, torch::Tensor> moe_recover_topK_bwd_op(
844844

845845
break;
846846
}
847-
// #endif
848-
// #ifdef ENABLE_FP8
847+
#endif
848+
#ifdef ENABLE_FP8
849849
case at::ScalarType::Float8_e5m2:
850850
{
851851
using dType = cutlass::float_e5m2_t;
@@ -894,7 +894,7 @@ std::tuple<torch::Tensor, torch::Tensor> moe_recover_topK_bwd_op(
894894

895895
break;
896896
}
897-
// #endif
897+
#endif
898898
default:
899899
throw std::runtime_error("Wrong activation tensor type.");
900900
}

setup.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
66

77

8-
if os.environ.get("TORCH_CUDA_ARCH_LIST"):
8+
# Supported NVIDIA GPU architectures.
9+
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
10+
11+
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
12+
# e.g. "9.0" or "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX". Here,
13+
# the "9.0+PTX" option asks the
14+
# compiler to additionally include PTX code that can be runtime-compiled
15+
# and executed on the 9.0 or newer architectures. While the PTX code will
16+
# not give the best performance on the newer architectures, it provides
17+
# forward compatibility.
18+
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
19+
if env_arch_list:
920
# Let PyTorch builder to choose device to target for.
1021
device_capability = ""
1122
else:
@@ -16,6 +27,8 @@
1627

1728
nvcc_flags = [
1829
"-std=c++17", # NOTE: CUTLASS requires c++17
30+
"-DENABLE_BF16", # Enable BF16 for cuda_version >= 11
31+
# "-DENABLE_FP8", # Enable FP8 for cuda_version >= 11.8
1932
]
2033

2134
if device_capability:

0 commit comments

Comments
 (0)