1+ option (USE_CUDA "Support NVIDIA CUDA" OFF )
2+ option (BUILD_TEST "Build tests" OFF )
3+ option (BUILD_TESTING "third party tests" OFF )
4+
5+ cmake_minimum_required (VERSION 3.28)
6+
7+ include (CMakeDependentOption)
8+ cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF )
9+ project (infini_train VERSION 0.3.0 LANGUAGES CXX)
10+
11+ set (CMAKE_CXX_STANDARD 20)
12+ set (CMAKE_CXX_STANDARD_REQUIRED ON )
13+ set (CMAKE_CXX_EXTENSIONS OFF )
14+
15+ # Generate compile_commands.json
16+ set (CMAKE_EXPORT_COMPILE_COMMANDS ON )
17+
18+ # Add gflags
19+ add_subdirectory (third_party/gflags)
20+ include_directories (${gflags_SOURCE_DIR} /include )
21+
22+ set (WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE)
23+ set (WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE)
24+
25+ # Add glog
26+ add_subdirectory (third_party/glog)
27+ include_directories (${glog_SOURCE_DIR} /src)
28+
29+ # Add eigen
30+ find_package (OpenMP REQUIRED)
31+ add_subdirectory (third_party/eigen)
32+ include_directories (${PROJECT_SOURCE_DIR} /third_party/eigen)
33+
34+ include_directories (${PROJECT_SOURCE_DIR} )
35+ file (GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR} /infini_train/src/*.cc)
36+ list (FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*" )
37+
38+ file (GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR} /infini_train/src/kernels/cpu/*.cc)
39+ add_library (infini_train_cpu_kernels STATIC ${CPU_KERNELS} )
40+ target_link_libraries (infini_train_cpu_kernels glog Eigen3::Eigen OpenMP::OpenMP_CXX)
41+
42+ if (USE_CUDA)
43+ add_compile_definitions (USE_CUDA=1)
44+ enable_language (CUDA)
45+ include (FindCUDAToolkit)
46+
47+ # enable CUDA-related compilation options
48+ set (CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr" )
49+ file (GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR} /infini_train/src/*.cu)
50+ add_library (infini_train_cuda_kernels STATIC ${CUDA_KERNELS} )
51+ set_target_properties (infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80" )
52+ target_link_libraries (infini_train_cuda_kernels glog CUDA::cudart CUDA::cublas)
53+
54+ add_library (infini_train STATIC ${SRC} )
55+ target_link_libraries (infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels infini_train_cuda_kernels "-Wl,--no-whole-archive" )
56+ else ()
57+ add_library (infini_train STATIC ${SRC} )
58+ target_link_libraries (infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels "-Wl,--no-whole-archive" )
59+ endif ()
60+
61+ if (BUILD_TEST)
62+ set (BUILD_GMOCK
63+ OFF
64+ CACHE BOOL "Do not build gmock" FORCE)
65+ set (INSTALL_GTEST
66+ OFF
67+ CACHE BOOL "Do not install gtest" FORCE)
68+ add_subdirectory (third_party/googletest)
69+ include_directories (third_party/googletest/googletest/include )
70+ endif ()
71+
72+ add_library (example_gpt2 STATIC
73+ example/common/tiny_shakespeare_dataset.cc
74+ example/common/tokenizer.cc
75+ example/gpt2/net.cc
76+ )
77+ target_link_libraries (example_gpt2 infini_train)
78+
79+ function (build_test files )
80+ # Non-recursive glob for skip failed tests
81+ file (GLOB TEST_SOURCES ${files} )
82+ foreach (testsourcefile ${TEST_SOURCES} )
83+ get_filename_component (testname ${testsourcefile} NAME_WE )
84+ add_executable (${testname} ${testsourcefile} )
85+ target_link_libraries (${testname} infini_train example_gpt2 GTest::gtest_main)
86+ add_test (NAME ${testname} COMMAND ${testname} )
87+ endforeach (testsourcefile ${TEST_SOURCES} )
88+ endfunction ()
89+
90+ if (BUILD_TEST)
91+ add_compile_definitions (BUILD_TEST=1)
92+ enable_testing ()
93+ if (BUILD_TEST_CORE)
94+ build_test(test /autograd/test_elementwise.cc)
95+ build_test(test /kernels/test_matmul.cc)
96+ build_test(test /kernels/test_dispatcher.cc)
97+ build_test(test /tensor/test_tensor.cc)
98+ build_test(test /optimizer/test_adam.cc)
99+ build_test(test /example/test_gpt2.cc)
100+ if (USE_CUDA)
101+ build_test(test /kernels/test_matmul_cuda.cc)
102+ build_test(test /optimizer/test_adam_cuda.cc)
103+ endif ()
104+ endif ()
105+ endif ()
0 commit comments