|
17 | 17 |
|
18 | 18 |
|
19 | 19 | class PngWriter: |
20 | | - def __init__(self, outdir, prompt=None): |
| 20 | + def __init__(self, outdir): |
21 | 21 | self.outdir = outdir |
22 | | - self.prompt = prompt |
23 | | - self.filepath = None |
24 | | - self.files_written = [] |
25 | 22 | os.makedirs(outdir, exist_ok=True) |
26 | 23 |
|
27 | | - def write_image(self, image, seed, upscaled=False): |
28 | | - self.filepath = self.unique_filename( |
29 | | - seed, upscaled, self.filepath |
30 | | - ) # will increment name in some sensible way |
31 | | - try: |
32 | | - prompt = f'{self.prompt} -S{seed}' |
33 | | - self.save_image_and_prompt_to_png(image, prompt, self.filepath) |
34 | | - except IOError as e: |
35 | | - print(e) |
36 | | - if not upscaled: |
37 | | - self.files_written.append([self.filepath, seed]) |
38 | | - |
39 | | - def unique_filename(self, seed, upscaled=False, previouspath=None): |
40 | | - revision = 1 |
41 | | - |
42 | | - if previouspath is None: |
43 | | - # sort reverse alphabetically until we find max+1 |
44 | | - dirlist = sorted(os.listdir(self.outdir), reverse=True) |
45 | | - # find the first filename that matches our pattern or return 000000.0.png |
46 | | - filename = next( |
47 | | - (f for f in dirlist if re.match('^(\d+)\..*\.png', f)), |
48 | | - '0000000.0.png', |
49 | | - ) |
50 | | - basecount = int(filename.split('.', 1)[0]) |
51 | | - basecount += 1 |
52 | | - filename = f'{basecount:06}.{seed}.png' |
53 | | - return os.path.join(self.outdir, filename) |
54 | | - |
55 | | - else: |
56 | | - basename = os.path.basename(previouspath) |
57 | | - x = re.match('^(\d+)\..*\.png', basename) |
58 | | - if not x: |
59 | | - return self.unique_filename(seed, upscaled, previouspath) |
60 | | - |
61 | | - basecount = int(x.groups()[0]) |
62 | | - series = 0 |
63 | | - finished = False |
64 | | - while not finished: |
65 | | - series += 1 |
66 | | - filename = f'{basecount:06}.{seed}.png' |
67 | | - path = os.path.join(self.outdir, filename) |
68 | | - finished = not os.path.exists(path) |
69 | | - return os.path.join(self.outdir, filename) |
70 | | - |
71 | | - def save_image_and_prompt_to_png(self, image, prompt, path): |
| 24 | + # gives the next unique prefix in outdir |
| 25 | + def unique_prefix(self): |
| 26 | + # sort reverse alphabetically until we find max+1 |
| 27 | + dirlist = sorted(os.listdir(self.outdir), reverse=True) |
| 28 | + # find the first filename that matches our pattern or return 000000.0.png |
| 29 | + existing_name = next( |
| 30 | + (f for f in dirlist if re.match('^(\d+)\..*\.png', f)), |
| 31 | + '0000000.0.png', |
| 32 | + ) |
| 33 | + basecount = int(existing_name.split('.', 1)[0]) + 1 |
| 34 | + return f'{basecount:06}' |
| 35 | + |
| 36 | + # saves image named _image_ to outdir/name, writing metadata from prompt |
| 37 | + # returns full path of output |
| 38 | + def save_image_and_prompt_to_png(self, image, prompt, name): |
| 39 | + path = os.path.join(self.outdir, name) |
72 | 40 | info = PngImagePlugin.PngInfo() |
73 | 41 | info.add_text('Dream', prompt) |
74 | 42 | image.save(path, 'PNG', pnginfo=info) |
| 43 | + return path |
75 | 44 |
|
| 45 | + # TODO move this to its own helper function; it's not really a method of pngwriter |
76 | 46 | def make_grid(self, image_list, rows=None, cols=None): |
77 | 47 | image_cnt = len(image_list) |
78 | 48 | if None in (rows, cols): |
|
0 commit comments