1818from functools import partial
1919import logging
2020import math
21+ import os
2122import platform
2223import unittest
2324from unittest import mock
@@ -539,6 +540,7 @@ def test_backend_serialization_deserialization(self):
539540 def test_persistent_cache_enable_xla_caches (self ):
540541 if jtu .jaxlib_version () <= (0 , 4 , 35 ):
541542 self .skipTest ("Test requires AutotuneCacheMode bindings" )
543+ s = os .sep
542544 with config .compilation_cache_dir ("jax-cache" ):
543545 with config .persistent_cache_enable_xla_caches ("none" ):
544546 compile_options = compiler .get_compile_options (
@@ -552,15 +554,15 @@ def test_persistent_cache_enable_xla_caches(self):
552554 compile_options = compiler .get_compile_options (
553555 num_replicas = 1 , num_partitions = 1
554556 )
555- self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_kernel_cache_file , "jax-cache/ xla_gpu_kernel_cache_file" )
557+ self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_kernel_cache_file , f "jax-cache{ s } xla_gpu_kernel_cache_file" )
556558 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_enable_llvm_module_compilation_parallelism , True )
557- self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_per_fusion_autotune_cache_dir , "jax-cache/ xla_gpu_per_fusion_autotune_cache_dir" )
559+ self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_per_fusion_autotune_cache_dir , f "jax-cache{ s } xla_gpu_per_fusion_autotune_cache_dir" )
558560 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_experimental_autotune_cache_mode , xc .AutotuneCacheMode .UPDATE )
559561 with config .persistent_cache_enable_xla_caches ("xla_gpu_kernel_cache_file" ):
560562 compile_options = compiler .get_compile_options (
561563 num_replicas = 1 , num_partitions = 1
562564 )
563- self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_kernel_cache_file , "jax-cache/ xla_gpu_kernel_cache_file" )
565+ self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_kernel_cache_file , f "jax-cache{ s } xla_gpu_kernel_cache_file" )
564566 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_enable_llvm_module_compilation_parallelism , True )
565567 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_per_fusion_autotune_cache_dir , "" )
566568 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_experimental_autotune_cache_mode , xc .AutotuneCacheMode .UPDATE )
@@ -570,7 +572,7 @@ def test_persistent_cache_enable_xla_caches(self):
570572 )
571573 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_kernel_cache_file , "" )
572574 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_enable_llvm_module_compilation_parallelism , False )
573- self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_per_fusion_autotune_cache_dir , "jax-cache/ xla_gpu_per_fusion_autotune_cache_dir" )
575+ self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_per_fusion_autotune_cache_dir , f "jax-cache{ s } xla_gpu_per_fusion_autotune_cache_dir" )
574576 self .assertEqual (compile_options .executable_build_options .debug_options .xla_gpu_experimental_autotune_cache_mode , xc .AutotuneCacheMode .UPDATE )
575577
576578@jtu .with_config (
0 commit comments