|
| 1 | +from tensorflow.keras.models import load_model |
| 2 | +from tensorflow.keras.utils import get_custom_objects |
| 3 | +import argparse |
| 4 | +import tensorflow as tf |
| 5 | +import numpy as np |
| 6 | +from tensorflow.keras.models import Model |
| 7 | +from tensorflow.keras.layers import Input, Cropping1D, add, Conv1D, GlobalAvgPool1D, Dense, Add, Concatenate, Lambda, Flatten |
| 8 | +import time |
| 9 | +import os |
| 10 | +import argparse |
| 11 | + |
| 12 | +def parse_args(): |
| 13 | + parser = argparse.ArgumentParser(description="Reformat chrombpnet h5 file") |
| 14 | + parser.add_argument("-cnb", "--chrombpnet_nb", type=str, required=True, help="Path to chrombpnet no bias model") |
| 15 | + parser.add_argument("-bm", "--bias_model_scaled", type=str, required=True, help="Path to scaled bias model") |
| 16 | + parser.add_argument("-o", "--output_dir", type=str, required=True, help="Path to output dir") |
| 17 | + args = parser.parse_args() |
| 18 | + return args |
| 19 | + |
| 20 | + |
| 21 | +args = parse_args() |
| 22 | + |
| 23 | +args_chrombpnet_nb=args.chrombpnet_nb |
| 24 | +args_bias=args.bias_model_scaled |
| 25 | +args_output_dir=args.output_dir |
| 26 | + |
| 27 | +def chrombpnet_model(bias_model, bpnet_model_wo_bias): |
| 28 | + inp = Input(shape=(2114, 4),name='sequence') |
| 29 | + bias_output=bias_model(inp) |
| 30 | + bpnet_model_wo_bias_new=Model(inputs=bpnet_model_wo_bias.inputs,outputs=bpnet_model_wo_bias.outputs, name="model_wo_bias") |
| 31 | + output_wo_bias=bpnet_model_wo_bias_new(inp) |
| 32 | + |
| 33 | + profile_out = Add(name="logits_profile_predictions")([output_wo_bias[0],bias_output[0]]) |
| 34 | + concat_counts = Concatenate(axis=-1)([output_wo_bias[1], bias_output[1]]) |
| 35 | + count_out = Lambda(lambda x: tf.math.reduce_logsumexp(x, axis=-1, keepdims=True), |
| 36 | + name="logcount_predictions")(concat_counts) |
| 37 | + model=Model(inputs=[inp],outputs=[profile_out, count_out]) |
| 38 | + return model |
| 39 | + |
| 40 | +def main(args_chrombpnet_nb, args_bias, args_output_dir): |
| 41 | + |
| 42 | + custom_objects={"tf":tf} |
| 43 | + get_custom_objects().update(custom_objects) |
| 44 | + |
| 45 | + chrombpnet_nb=load_model(args_chrombpnet_nb,compile=False) |
| 46 | + bias_model=load_model(args_bias,compile=False) |
| 47 | + |
| 48 | + newp = args_output_dir+"/chrombpnet_recompiled.h5" |
| 49 | + new_chrom = chrombpnet_model(bias_model, chrombpnet_nb) |
| 50 | + new_chrom.save(newp) |
| 51 | + newp = args_output_dir+"/chrombpnet_recompiled" |
| 52 | + new_chrom.save(newp) |
| 53 | + |
| 54 | +if __name__ == '__main__': |
| 55 | + |
| 56 | + main(args_chrombpnet_nb, args_bias, args_output_dir) |
| 57 | + |
| 58 | + |
| 59 | + |
| 60 | + |
0 commit comments