Skip to content

Commit d6a0c7f

Browse files
authored
Merge pull request #53 from aluamorim/minor_fixes
Minor fixes
2 parents f695c7c + 032543d commit d6a0c7f

File tree

9 files changed

+134
-35
lines changed

9 files changed

+134
-35
lines changed

examples/use_modernised.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pyrevolve as pr
22
import numpy as np
3-
import collections
3+
from collections.abc import Mapping
44

55

66
class Symbol(object):
@@ -30,7 +30,7 @@ def __init__(self, u, m):
3030
self.m = m
3131

3232
def apply(self, t_start, t_end):
33-
print((">"*(t_end-t_start)).rjust(t_end))
33+
# print((">"*(t_end-t_start)).rjust(t_end))
3434
for i in range(t_start, t_end):
3535
u.data = u.data + m.data
3636

@@ -42,7 +42,7 @@ def __init__(self, u, m, v):
4242
self.m = m
4343

4444
def apply(self, t_start, t_end):
45-
print(("<"*(t_end-t_start)).rjust(t_end))
45+
# print(("<"*(t_end-t_start)).rjust(t_end))
4646
for i in range(t_end, t_start, -1):
4747
v.data = v.data + m.data
4848

@@ -55,7 +55,7 @@ def __init__(self, symbols):
5555
stores only a reference to the symbols that are passed into it.
5656
The symbols must be passed as a mapping symbolname->symbolobject."""
5757

58-
if(isinstance(symbols, collections.Mapping)):
58+
if(isinstance(symbols, Mapping)):
5959
self.symbols = symbols
6060
else:
6161
raise Exception("Symbols must be a Mapping, for example a \

pyrevolve/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_dict(self):
5252
results = {}
5353
for s_n, s_dict in self.timings.items():
5454
for a_n, a_time in s_dict.items():
55-
results['%s_%s_timing' % (s_n, a_n)] = a_time
55+
results['%s_%s_timing' % (s_n, a_n)] = "{:.2f}".format(a_time)
5656

5757
for s_n, s_dict in self.counts.items():
5858
for a_n, a_time in s_dict.items():

pyrevolve/pyrevolve.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ def addByteStorage(self, compression_params):
171171
def makespan(self):
172172
return 0
173173

174+
@property
175+
def ratio(self):
176+
return 0
177+
174178
def apply_forward(self):
175179
"""Executes only the forward computation while storing checkpoints,
176180
then returns."""
@@ -268,6 +272,11 @@ def load_checkpoint(self, st_idx=0):
268272
def remove_checkpoint(self, st_idx=0):
269273
return NotImplemented
270274

275+
def storage_ckps(self, k=0):
276+
"""Returns a list of all checkpoint keys stored at the k-th
277+
storage level"""
278+
return NotImplemented
279+
271280

272281
class SingleLevelRevolver(BaseRevolver):
273282
"""
@@ -345,6 +354,16 @@ def __init__(
345354
self.compression_params = compression_params
346355
self.addNumpyStorage(compression_params)
347356

357+
@property
358+
def ratio(self):
359+
return self.scheduler.ratio
360+
361+
def storage_ckps(self, k=0):
362+
"""Returns a list of all checkpoint keys stored at the k-th
363+
storage level"""
364+
# single level always uses first storage object on storage_list
365+
return self.scheduler.storage(0)
366+
348367

349368
class MultiLevelRevolver(BaseRevolver):
350369
"""
@@ -389,7 +408,6 @@ def __init__(
389408
checkpoint: checkpoint object
390409
fwd_operator: forward operator
391410
rev_operator: backward operator
392-
n_checkpoints: number of checkpoints
393411
n_timesteps: number of timesteps
394412
timings: timings
395413
profiler: profiler
@@ -443,6 +461,15 @@ def reload_scheduler(self, uf=1, ub=1, up=1):
443461
def makespan(self):
444462
return self.scheduler.makespan
445463

464+
@property
465+
def ratio(self):
466+
return self.scheduler.ratio
467+
468+
def storage_ckps(self, k=0):
469+
"""Returns a list of all checkpoint keys stored at the k-th
470+
storage level"""
471+
return self.scheduler.storage(k)
472+
446473
def save_checkpoint(self, st_idx=0):
447474
data_pointers = self.checkpoint.get_data(self.scheduler.capo)
448475
self.storage_list[st_idx].push(data_pointers)

pyrevolve/schedulers/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def __repr__(self):
102102
dict(
103103
{
104104
"type": self.type_names[self.type],
105-
"capo": self.capo,
106-
"old_capo": self.old_capo,
105+
"from": self.old_capo,
106+
"to": self.capo,
107107
"ckp": self.ckp,
108108
}
109109
)
@@ -143,3 +143,7 @@ def old_capo(self):
143143
@property
144144
def cp_pointer(self):
145145
return 0
146+
147+
@property
148+
def oplist(self):
149+
return None

pyrevolve/schedulers/crevolve.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,30 @@ class CRevolve(Scheduler):
3535

3636
def __init__(self, n_checkpoints, n_timesteps):
3737
super().__init__(n_checkpoints, n_timesteps)
38-
self.revolve = cr.CRevolve(n_checkpoints, n_timesteps, None)
3938
self.__revstart_action = None
39+
self.revolve = cr.CRevolve(n_checkpoints, n_timesteps, None)
40+
self.__oplist = []
41+
self.__stored_ckps = []
42+
self.__ratio = self.__calc_ratio()
43+
self.revolve = cr.CRevolve(n_checkpoints, n_timesteps, None)
44+
45+
def __calc_ratio(self):
46+
fcomp = 0
47+
ca = self.next()
48+
self.__oplist.append(ca)
49+
while ca.type != Action.TERMINATE:
50+
if (ca.type == Action.ADVANCE) or (ca.type == Action.LASTFW):
51+
st = ca.old_capo
52+
end = ca.capo
53+
fcomp += (end-st)
54+
ca = self.next()
55+
self.__oplist.append(ca)
56+
57+
return (fcomp/self.n_timesteps)
58+
59+
@property
60+
def oplist(self):
61+
return self.__oplist
4062

4163
def next(self):
4264
if self.__revstart_action is None:
@@ -53,6 +75,8 @@ def next(self):
5375
old_capo=self.old_capo,
5476
ckp=self.cp_pointer,
5577
)
78+
if ca.type is Action.TAKESHOT:
79+
self.__stored_ckps.append(ca.capo)
5680
else:
5781
ca = self.__revstart_action
5882
self.__revstart_action = None
@@ -69,3 +93,12 @@ def old_capo(self):
6993
@property
7094
def cp_pointer(self):
7195
return self.revolve.check
96+
97+
@property
98+
def ratio(self):
99+
return self.__ratio
100+
101+
def storage(self, k):
102+
"""Returns a list of all checkpoint keys stored at the k-th
103+
storage level. For CRevolve, k is always 0 """
104+
return self.__stored_ckps

pyrevolve/schedulers/hrevolve.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,10 @@ def __init__(self, n_checkpoints, n_timesteps, architecture=None, uf=1, ub=1, up
662662
def resetSequence(self):
663663
self.__copindex = 0
664664

665+
@property
666+
def oplist(self):
667+
return self.__oplist
668+
665669
def next(self):
666670
if self.__copindex >= self.n_ops:
667671
ha = HAction(action_type=Action.TERMINATE)
@@ -729,6 +733,14 @@ def __check_for_cpdel_condition(self):
729733

730734
return ret
731735

736+
def storage(self, k):
737+
"""Returns a list of all checkpoint keys stored at the k-th
738+
storage level"""
739+
if k < self.architecture.nblevels:
740+
return self.__sequence.storage[k]
741+
else:
742+
return None
743+
732744
@property
733745
def capo(self):
734746
return self.__capo
@@ -745,6 +757,20 @@ def cp_pointer(self):
745757
def makespan(self):
746758
return self.__sequence.makespan
747759

760+
@property
761+
def ratio(self):
762+
# compute recomputation ratio:
763+
fcomp = 0
764+
for op in self.__oplist:
765+
if op.type == "Forwards":
766+
st = op.index[0]
767+
end = op.index[1]
768+
fcomp += (end-st)
769+
elif op.type == "Forward":
770+
fcomp += 1
771+
772+
return (fcomp/self.n_timesteps)
773+
748774
def hrevolve_aux(self, l, K, cmem, hoptp=None, hopt=None):
749775
"""
750776
This function is a copy of the orginal HRevolve_Aux

pyrevolve/storage.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, size_ckp, n_ckp, dtype, profiler, wd=0, rd=0, name="storage")
5454
self.name = name
5555
self.__maxsize = self.size_ckp * n_ckp
5656
self.__current_size = 0
57+
self.__itemsize = np.dtype(self.dtype).itemsize
5758

5859
# stack interface controls
5960
self.__stack_ptr = -1
@@ -97,12 +98,17 @@ def isEmpty(self):
9798

9899
@property
99100
def maxsize(self):
100-
"""Returns the maximum storage size in bytes"""
101+
"""Returns the maximum storage size in words of 'dtype'"""
101102
return self.__maxsize
102103

104+
@property
105+
def maxsize_in_bytes(self):
106+
"""Returns the maximum storage size in bytes"""
107+
return self.__maxsize*self.__itemsize
108+
103109
@property
104110
def size(self):
105-
"""Returns the current storage size in bytes"""
111+
"""Returns the current storage size in words of 'dtype'"""
106112
return self.__current_size
107113

108114
@property
@@ -216,12 +222,13 @@ def save(self, key, data_pointers):
216222
else:
217223
ckpfile = self.datFileName + (".k%d" % (key))
218224
self.checkFilesDir()
219-
slot = open(ckpfile, "bw+")
225+
slot = open(ckpfile, "bw")
220226

221227
for ptr in data_pointers:
222228
assert ptr.strides[-1] == ptr.itemsize
223229
with self.profiler.get_timer("storage", "flatten"):
224230
data = ptr.ravel()
231+
data = data.astype(self.dtype)
225232
data.tofile(slot)
226233
slot.flush()
227234
self.__current_size += self.size_ckp
@@ -238,7 +245,7 @@ def load(self, key, locations):
238245
else:
239246
ckpfile = self.datFileName + (".k%d" % (key))
240247
self.checkFilesDir()
241-
slot = open(ckpfile, "br+")
248+
slot = open(ckpfile, "br")
242249

243250
offset = 0
244251
for shape, ptr in zip(self.shapes[key], locations):
@@ -283,6 +290,7 @@ def save(self, key, data_pointers):
283290
data = ptr.ravel()
284291
with self.profiler.get_timer("storage", "copy_save"):
285292
np.copyto(slot[offset:(len(data) + offset)], data)
293+
286294
offset += len(data)
287295
shapes.append(ptr.shape)
288296
self.shapes[key] = shapes

pyrevolve/tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def stop(self):
4848
"""
4949
Stop capturing the stream data and save the text in `capturedtext`.
5050
"""
51+
5152
# Print the escape character to make the readOutput method stop:
5253
self.origstream.write(self.escape_char)
5354
# Flush the stream to make sure all our data goes in before

0 commit comments

Comments
 (0)