|
1 | | -import numpy as np |
2 | | -import io,operator,sys,os,re,mimetypes,itertools,shutil,pickle,tempfile,subprocess |
3 | | -import itertools,random,inspect,functools,math,bz2,typing,numbers,warnings,threading |
4 | | -import json,urllib.request |
| 1 | +import sys,os,re,shutil,typing,itertools,operator,functools,math,warnings,functools,inspect,io |
5 | 2 |
|
| 3 | +from operator import itemgetter,attrgetter |
6 | 4 | from warnings import warn |
7 | | -from dataclasses import dataclass |
| 5 | +from typing import Iterable,Generator,Sequence,Iterator |
8 | 6 | from functools import partial,reduce |
9 | | -from threading import Thread |
10 | | -from time import sleep |
11 | | -from copy import copy |
12 | | -from contextlib import redirect_stdout,contextmanager |
13 | | -from collections.abc import Iterable,Iterator,Generator,Collection,Sequence |
14 | | -from types import SimpleNamespace |
15 | 7 | from pathlib import Path |
16 | | -from collections import defaultdict,Counter |
17 | | -from operator import itemgetter,attrgetter |
18 | | -from uuid import uuid4 |
19 | | -from urllib.request import HTTPError |
20 | | - |
21 | | -# External modules |
22 | | -from numpy import array,ndarray |
23 | | -from pdb import set_trace |
24 | | - |
25 | | -#Optional modules |
26 | | -try: import matplotlib.pyplot as plt |
27 | | -except: pass |
28 | 8 |
|
29 | 9 | try: |
30 | 10 | from types import WrapperDescriptorType,MethodWrapperType,MethodDescriptorType |
31 | 11 | except ImportError: |
32 | 12 | WrapperDescriptorType = type(object.__init__) |
33 | 13 | MethodWrapperType = type(object().__str__) |
34 | 14 | MethodDescriptorType = type(str.join) |
35 | | -from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType |
| 15 | +from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace |
36 | 16 |
|
37 | 17 | NoneType = type(None) |
38 | 18 | string_classes = (str,bytes) |
@@ -60,17 +40,49 @@ def noops(self, x=None, *args, **kwargs): |
60 | 40 | "Do nothing (method)" |
61 | 41 | return x |
62 | 42 |
|
63 | | -def one_is_instance(a, b, t): return isinstance(a,t) or isinstance(b,t) |
| 43 | +def any_is_instance(t, *args): return any(isinstance(a,t) for a in args) |
| 44 | + |
| 45 | +def isinstance_str(x, cls_name): |
| 46 | + "Like `isinstance`, except takes a type name instead of a type" |
| 47 | + return cls_name in [t.__name__ for t in type(x).__mro__] |
| 48 | + |
| 49 | +def array_equal(a,b): |
| 50 | + if hasattr(a, '__array__'): a = a.__array__() |
| 51 | + if hasattr(b, '__array__'): b = b.__array__() |
| 52 | + return (a==b).all() |
64 | 53 |
|
65 | 54 | def equals(a,b): |
66 | 55 | "Compares `a` and `b` for equality; supports sublists, tensors and arrays too" |
67 | 56 | if (a is None) ^ (b is None): return False |
68 | | - if one_is_instance(a,b,type): return a==b |
| 57 | + if any_is_instance(type,a,b): return a==b |
69 | 58 | if hasattr(a, '__array_eq__'): return a.__array_eq__(b) |
70 | 59 | if hasattr(b, '__array_eq__'): return b.__array_eq__(a) |
71 | | - cmp = (np.array_equal if one_is_instance(a, b, ndarray ) else |
72 | | - operator.eq if one_is_instance(a, b, (str,dict,set)) else |
73 | | - all_equal if is_iter(a) or is_iter(b) else |
| 60 | + cmp = (array_equal if isinstance_str(a, 'ndarray') or isinstance_str(b, 'ndarray') else |
| 61 | + operator.eq if any_is_instance((str,dict,set), a, b) else |
| 62 | + all_equal if is_iter(a) or is_iter(b) else |
74 | 63 | operator.eq) |
75 | 64 | return cmp(a,b) |
76 | 65 |
|
| 66 | +def ipython_shell(): |
| 67 | + "Same as `get_ipython` but returns `False` if not in IPython" |
| 68 | + try: return get_ipython() |
| 69 | + except NameError: return False |
| 70 | + |
| 71 | +def in_ipython(): |
| 72 | + "Check if code is running in some kind of IPython environment" |
| 73 | + return bool(ipython_shell()) |
| 74 | + |
| 75 | +def in_colab(): |
| 76 | + "Check if the code is running in Google Colaboratory" |
| 77 | + return 'google.colab' in sys.modules |
| 78 | + |
| 79 | +def in_jupyter(): |
| 80 | + "Check if the code is running in a jupyter notebook" |
| 81 | + if not in_ipython(): return False |
| 82 | + return ipython_shell().__class__.__name__ == 'ZMQInteractiveShell' |
| 83 | + |
| 84 | +def in_notebook(): |
| 85 | + "Check if the code is running in a jupyter notebook" |
| 86 | + return in_colab() or in_jupyter() |
| 87 | + |
| 88 | +IN_IPYTHON,IN_JUPYTER,IN_COLAB,IN_NOTEBOOK = in_ipython(),in_jupyter(),in_colab(),in_notebook() |
0 commit comments