11{
22 lib ,
3+ stdenv ,
34 config ,
45 buildPythonPackage ,
56 fetchFromGitHub ,
7+
8+ # patches
69 replaceVars ,
710 addDriverRunpath ,
8- cudaSupport ? config . cudaSupport ,
9- rocmSupport ? config . rocmSupport ,
1011 cudaPackages ,
12+ llvmPackages ,
1113 ocl-icd ,
1214 rocmPackages ,
13- stdenv ,
1415
1516 # build-system
1617 setuptools ,
2122 unicorn ,
2223
2324 # tests
25+ pytestCheckHook ,
26+ writableTmpDirAsHomeHook ,
2427 blobfile ,
2528 bottle ,
29+ capstone ,
2630 clang ,
2731 hexdump ,
2832 hypothesis ,
33+ jax ,
2934 librosa ,
3035 networkx ,
3136 numpy ,
3237 onnx ,
38+ onnxruntime ,
3339 pillow ,
3440 pytest-xdist ,
35- pytestCheckHook ,
3641 safetensors ,
3742 sentencepiece ,
3843 tiktoken ,
3944 torch ,
4045 tqdm ,
4146 transformers ,
4247
48+ # passthru
4349 tinygrad ,
50+
51+ cudaSupport ? config . cudaSupport ,
52+ rocmSupport ? config . rocmSupport ,
4453} :
4554
4655buildPythonPackage rec {
4756 pname = "tinygrad" ;
48- version = "0.10.0 " ;
57+ version = "0.10.2 " ;
4958 pyproject = true ;
5059
5160 src = fetchFromGitHub {
5261 owner = "tinygrad" ;
5362 repo = "tinygrad" ;
5463 tag = "v${ version } " ;
55- hash = "sha256-IIyTb3jDUSEP2IXK6DLsI15E5N34Utt7xv86aTHpXf8 =" ;
64+ hash = "sha256-BXQMacp6QjlgsVwhp2pxEZkRylZfKQhqIh92/0dPlfg =" ;
5665 } ;
5766
5867 patches = [
@@ -68,14 +77,30 @@ buildPythonPackage rec {
6877
6978 postPatch =
7079 # Patch `clang` directly in the source file
80+ # Use the unwrapped variant to enable the "native" features currently unavailable in the sandbox
7181 ''
72- substituteInPlace tinygrad/runtime/ops_clang.py \
73- --replace-fail "'clang'" "'${ lib . getExe clang } '"
82+ substituteInPlace tinygrad/runtime/ops_cpu.py \
83+ --replace-fail "getenv(\"CC\", 'clang')" "'${ lib . getExe llvmPackages . clang-unwrapped } '"
84+ ''
85+ + ''
86+ substituteInPlace tinygrad/runtime/autogen/libc.py \
87+ --replace-fail "ctypes.util.find_library('c')" "'${ stdenv . cc . libc } /lib/libc.so.6'"
88+ ''
89+ + ''
90+ substituteInPlace tinygrad/runtime/support/llvm.py \
91+ --replace-fail "ctypes.util.find_library('LLVM')" "'${ lib . getLib llvmPackages . llvm } /lib/libLLVM.so'"
7492 ''
7593 + lib . optionalString stdenv . hostPlatform . isLinux ''
7694 substituteInPlace tinygrad/runtime/autogen/opencl.py \
7795 --replace-fail "ctypes.util.find_library('OpenCL')" "'${ ocl-icd } /lib/libOpenCL.so'"
7896 ''
97+ # test/test_tensor.py imports the PTX variable from the cuda_compiler.py file.
98+ # This import leads to loading the libnvrtc.so library that is not substituted when cudaSupport = false.
99+ # -> As a fix, we hardcode this variable to False
100+ + lib . optionalString ( ! cudaSupport ) ''
101+ substituteInPlace test/test_tensor.py \
102+ --replace-fail "from tinygrad.runtime.support.compiler_cuda import PTX" "PTX = False"
103+ ''
79104 # `cuda_fp16.h` and co. are needed at runtime to compile kernels
80105 + lib . optionalString cudaSupport ''
81106 substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
@@ -114,18 +139,23 @@ buildPythonPackage rec {
114139 ] ;
115140
116141 nativeCheckInputs = [
142+ pytestCheckHook
143+ writableTmpDirAsHomeHook
144+
117145 blobfile
118146 bottle
147+ capstone
119148 clang
120149 hexdump
121150 hypothesis
151+ jax
122152 librosa
123153 networkx
124154 numpy
125155 onnx
156+ onnxruntime
126157 pillow
127158 pytest-xdist
128- pytestCheckHook
129159 safetensors
130160 sentencepiece
131161 tiktoken
@@ -134,15 +164,15 @@ buildPythonPackage rec {
134164 transformers
135165 ] ++ networkx . optional-dependencies . extra ;
136166
137- preCheck = ''
138- export HOME=$(mktemp -d)
139- '' ;
140-
141167 disabledTests =
142168 [
143- # Fixed in https://github.com/tinygrad/tinygrad/pull/7792
144- # TODO: re-enable at next release
145- "test_kernel_cache_in_action"
169+ # RuntimeError: Attempting to relocate against an undefined symbol 'fmaxf'
170+ "test_backward_sum_acc_dtype"
171+ "test_failure_27"
172+
173+ # Flaky:
174+ # AssertionError: 2.1376906810000946 not less than 2.0
175+ "test_recursive_pad"
146176
147177 # Require internet access
148178 "test_benchmark_openpilot_model"
@@ -178,10 +208,6 @@ buildPythonPackage rec {
178208 "test_vgg7"
179209 ]
180210 ++ lib . optionals ( stdenv . hostPlatform . system == "aarch64-linux" ) [
181- # Fixed in https://github.com/tinygrad/tinygrad/pull/7796
182- # TODO: re-enable at next release
183- "test_interpolate_bilinear"
184-
185211 # Fail with AssertionError
186212 "test_casts_from"
187213 "test_casts_to"
@@ -209,7 +235,9 @@ buildPythonPackage rec {
209235 changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${ version } " ;
210236 license = lib . licenses . mit ;
211237 maintainers = with lib . maintainers ; [ GaetanLepage ] ;
212- # Tests segfault on darwin
213- badPlatforms = [ lib . systems . inspect . patterns . isDarwin ] ;
238+ badPlatforms = [
239+ # Tests segfault on darwin
240+ lib . systems . inspect . patterns . isDarwin
241+ ] ;
214242 } ;
215243}
0 commit comments