|
339 | 339 | " \"`test` that `a==b` and are same type\"\n",
|
340 | 340 | " test_eq(a,b)\n",
|
341 | 341 | " test_eq(type(a),type(b))\n",
|
342 |
| - " if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b))" |
| 342 | + " if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b)) # type of each element\n", |
| 343 | + " if isinstance(a, (torch.Tensor, pd.Series, np.ndarray)): test_eq(a.dtype, b.dtype) # dtypes of both tensors" |
343 | 344 | ]
|
344 | 345 | },
|
345 | 346 | {
|
|
354 | 355 | "test_fail(lambda: test_eq_type(1,1.))\n",
|
355 | 356 | "test_eq_type([1,1],[1,1])\n",
|
356 | 357 | "test_fail(lambda: test_eq_type([1,1],(1,1)))\n",
|
357 |
| - "test_fail(lambda: test_eq_type([1,1],[1,1.]))" |
| 358 | + "test_fail(lambda: test_eq_type([1,1],[1,1.]))\n", |
| 359 | + "test_fail(lambda: test_eq_type(torch.zeros(10), torch.zeros(10, dtype=torch.float64)))\n", |
| 360 | + "test_fail(lambda: test_eq_type(torch.zeros(10), np.zeros(10)))\n", |
| 361 | + "test_eq_type(torch.zeros(3), torch.Tensor([0, 0, 0]))" |
358 | 362 | ]
|
359 | 363 | },
|
360 | 364 | {
|
|
0 commit comments