|
339 | 339 | " \"`test` that `a==b` and are same type\"\n", |
340 | 340 | " test_eq(a,b)\n", |
341 | 341 | " test_eq(type(a),type(b))\n", |
342 | | - " if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b))" |
| 342 | + " if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b)) # type of each element\n", |
| 343 | + " if isinstance(a, (torch.Tensor, pd.Series, np.ndarray)): test_eq(a.dtype, b.dtype) # dtypes of both tensors" |
343 | 344 | ] |
344 | 345 | }, |
345 | 346 | { |
|
354 | 355 | "test_fail(lambda: test_eq_type(1,1.))\n", |
355 | 356 | "test_eq_type([1,1],[1,1])\n", |
356 | 357 | "test_fail(lambda: test_eq_type([1,1],(1,1)))\n", |
357 | | - "test_fail(lambda: test_eq_type([1,1],[1,1.]))" |
| 358 | + "test_fail(lambda: test_eq_type([1,1],[1,1.]))\n", |
| 359 | + "test_fail(lambda: test_eq_type(torch.zeros(10), torch.zeros(10, dtype=torch.float64)))\n", |
| 360 | + "test_fail(lambda: test_eq_type(torch.zeros(10), np.zeros(10)))\n", |
| 361 | + "test_eq_type(torch.zeros(3), torch.Tensor([0, 0, 0]))" |
358 | 362 | ] |
359 | 363 | }, |
360 | 364 | { |
|
0 commit comments