
Commit e0d1edc

enable compilation cache with grok (#395)
1 parent ec1f2b2 commit e0d1edc

3 files changed: +12 additions, -19 deletions

example/grok/README.md

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 Loading and running the Grok-1 open-weights model from [Grok-1](https://github.com/xai-org/grok-1)
 
+Running the Grok-1 model requires a GPU device with at least 8 tiles.
+
 ## 1. Install intel-extension-for-openxla
 
 Please go to the [main page](https://github.com/intel/intel-extension-for-openxla/blob/main/README.md#build-and-install) and follow the instructions to build and install intel-extension-for-openxla.
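The added requirement pins a hardware floor for this example. As a quick pre-flight check, the sketch below (not part of this commit) assumes each GPU tile is exposed to JAX as its own device:

import jax

# Grok-1 in this example is sharded across the tiles, so at least
# 8 JAX devices must be visible before launching a run.
devices = jax.devices()
print(f"{len(devices)} devices visible")
assert len(devices) >= 8, "Grok-1 needs at least an 8-tile GPU device"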

example/grok/inference.py

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,7 @@
 import time
 import json
 import os
+import jax
 
 from model import LanguageModelConfig, TransformerConfig
 from runners import InferenceRunner, ModelRunner, sample_from_model
@@ -12,9 +13,15 @@ def main(args):
     num_warmup = args.num_warmup
     input_tokens = args.input_tokens
     max_new_tokens = args.max_new_tokens
+    compilation_cache = args.compilation_cache
     input_len = int(input_tokens)
 
     current_path = str(os.path.dirname(__file__))
+
+    if compilation_cache:
+        COMPILATION_CACHE_PATH = current_path + "/compilation_cache/"
+        jax.config.update("jax_compilation_cache_dir", COMPILATION_CACHE_PATH)
+
     CKPT_PATH = current_path + "/checkpoints/"
     with open(current_path + "/prompt.json") as f:
         content = f.read()
@@ -86,5 +93,6 @@ def main(args):
     parser.add_argument("--num-warmup", default=1, type=int, help="num warmup")
     parser.add_argument("--input-tokens", default="32", choices=["32", "64", "128", "256", "512", "1024", "2016", "2017", "2048", "4096", "8192"], type=str, help="input tokens length if needed from prompt.json")
     parser.add_argument("--max-new-tokens", default=32, type=int, help="output max new tokens")
+    parser.add_argument("--compilation-cache", action="store_true", help="enable the JAX persistent compilation cache")
     args = parser.parse_args()
     main(args)
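The heart of the change is the jax_compilation_cache_dir setting, which enables JAX's persistent compilation cache: compiled XLA executables are serialized to disk and reused across processes, so the long warm-up compilation is paid only once. A minimal standalone sketch of that behavior (the function, shapes, and cache path are illustrative, not taken from the Grok example):

import jax
import jax.numpy as jnp

# Point the persistent compilation cache at a directory (hypothetical path).
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

@jax.jit
def matmul(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1024, 1024))
matmul(x)  # first call: compiles; large enough programs are persisted to disk
matmul(x)  # later calls in this process hit the in-memory cache

A fresh process pointed at the same cache directory deserializes the stored executable instead of recompiling, which is what makes repeated Grok-1 launches cheaper; note that recent JAX versions skip persisting compilations that fall below a minimum-compile-time threshold. With the flag defined as above, the cache is enabled by running, e.g., python inference.py --compilation-cache.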

third_party/openxla.patch

Lines changed: 2 additions & 19 deletions
@@ -1974,7 +1974,7 @@ index 0aa610fc9..3c4b34ace 100644
     MatrixIsColumnMajor(instr, gemm_backend_config));
 
 diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
-index d0c20aa1c..86ac26006 100644
+index d0c20aa1c..98ce30ebe 100644
 --- a/xla/service/gpu/gpu_compiler.cc
 +++ b/xla/service/gpu/gpu_compiler.cc
 @@ -209,6 +209,7 @@ limitations under the License.
@@ -2115,24 +2115,7 @@ index d0c20aa1c..86ac26006 100644
 }
 
 HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
-@@ -2148,6 +2175,7 @@ HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
-
- absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
-     Executable* executable) const {
-+#if 0
-   auto* gpu_executable = tensorflow::down_cast<GpuExecutable*>(executable);
-   if (!gpu_executable) return Internal("GpuExecutable is null");
-
-@@ -2155,6 +2183,8 @@ absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
-       &gpu_executable->module(), gpu_executable->buffer_assignment(),
-       gpu_executable->text(), gpu_executable->binary(),
-       gpu_executable->dnn_compiled_graphs());
-+#endif
-+  LOG(FATAL) << "GpuCompiler::Export is not implemented";
- }
-
- absl::Status GpuCompiler::RunPostSchedulingPipelines(
-@@ -2215,13 +2245,18 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
+@@ -2215,13 +2242,18 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
   auto driver_version = se::gpu::GpuDriver::GetDriverVersion();
 #if GOOGLE_CUDA
   constexpr int toolkit_version = CUDA_VERSION;
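This patch change is what lets the cache work on the Intel backend: the deleted lines had wrapped the body of GpuCompiler::Export in #if 0 and stubbed it with LOG(FATAL), and Export is the hook the compilation cache relies on to serialize a compiled GpuExecutable (presumably why this commit pairs the two changes), so restoring the stock implementation allows cache entries to be written. A hedged way to confirm the cache is being populated, reusing the directory layout from inference.py above:

import os

# Assumes inference.py was already run once with --compilation-cache;
# the path mirrors COMPILATION_CACHE_PATH from the diff above.
cache_dir = os.path.join(os.path.dirname(__file__), "compilation_cache")
entries = os.listdir(cache_dir)
print(f"{len(entries)} cache entries")  # should be non-empty after a cached run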

0 commit comments
