From e11d73d5a2e6c78faafbd9e6edd897e034569cea Mon Sep 17 00:00:00 2001 From: mya012 <96360296+mya012@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:27:00 +0800 Subject: [PATCH 1/6] Add files via upload --- datasets/txtfiles/HDR/train.txt | 226 +++++++++++++++++++++++++++++++ datasets/txtfiles/HDR/train1.txt | 226 +++++++++++++++++++++++++++++++ 2 files changed, 452 insertions(+) create mode 100644 datasets/txtfiles/HDR/train.txt create mode 100644 datasets/txtfiles/HDR/train1.txt diff --git a/datasets/txtfiles/HDR/train.txt b/datasets/txtfiles/HDR/train.txt new file mode 100644 index 0000000..42d0241 --- /dev/null +++ b/datasets/txtfiles/HDR/train.txt @@ -0,0 +1,226 @@ +./LDR/7U6A3450.CR2 ./HDR/train/7U6A3449_hdr +./LDR/7U6A3453.CR2 ./HDR/train/7U6A3452_hdr +./LDR/7U6A3456.CR2 ./HDR/train/7U6A3455_hdr +./LDR/7U6A3459.CR2 ./HDR/train/7U6A3458_hdr +./LDR/7U6A3462.CR2 ./HDR/train/7U6A3461_hdr +./LDR/7U6A3474.CR2 ./HDR/train/7U6A3473_hdr +./LDR/7U6A3477.CR2 ./HDR/train/7U6A3476_hdr +./LDR/7U6A3480.CR2 ./HDR/train/7U6A3479_hdr +./LDR/7U6A3495.CR2 ./HDR/train/7U6A3494_hdr +./LDR/7U6A3513.CR2 ./HDR/train/7U6A3512_hdr +./LDR/7U6A8060.CR2 ./HDR/train/7U6A8059_hdr +./LDR/7U6A8069.CR2 ./HDR/train/7U6A8068_hdr +./LDR/7U6A8075.CR2 ./HDR/train/7U6A8074_hdr +./LDR/7U6A8081.CR2 ./HDR/train/7U6A8080_hdr +./LDR/7U6A8084.CR2 ./HDR/train/7U6A8083_hdr +./LDR/7U6A8087.CR2 ./HDR/train/7U6A8086_hdr +./LDR/7U6A8096.CR2 ./HDR/train/7U6A8095_hdr +./LDR/7U6A8108.CR2 ./HDR/train/7U6A8107_hdr +./LDR/7U6A8111.CR2 ./HDR/train/7U6A8110_hdr +./LDR/7U6A8126.CR2 ./HDR/train/7U6A8125_hdr +./LDR/7U6A8129.CR2 ./HDR/train/7U6A8128_hdr +./LDR/7U6A8135.CR2 ./HDR/train/7U6A8134_hdr +./LDR/7U6A8141.CR2 ./HDR/train/7U6A8140_hdr +./LDR/7U6A8144.CR2 ./HDR/train/7U6A8143_hdr +./LDR/7U6A8159.CR2 ./HDR/train/7U6A8158_hdr +./LDR/7U6A8168.CR2 ./HDR/train/7U6A8167_hdr +./LDR/7U6A8174.CR2 ./HDR/train/7U6A8173_hdr +./LDR/7U6A8177.CR2 ./HDR/train/7U6A8176_hdr +./LDR/7U6A8180.CR2 ./HDR/train/7U6A8179_hdr +./LDR/7U6A8189.CR2 ./HDR/train/7U6A8188_hdr +./LDR/7U6A8192.CR2 ./HDR/train/7U6A8191_hdr +./LDR/7U6A8195.CR2 ./HDR/train/7U6A8194_hdr +./LDR/7U6A8201.CR2 ./HDR/train/7U6A8200_hdr +./LDR/7U6A8207.CR2 ./HDR/train/7U6A8206_hdr +./LDR/7U6A8219.CR2 ./HDR/train/7U6A8218_hdr +./LDR/7U6A8222.CR2 ./HDR/train/7U6A8221_hdr +./LDR/7U6A8225.CR2 ./HDR/train/7U6A8224_hdr +./LDR/7U6A8231.CR2 ./HDR/train/7U6A8230_hdr +./LDR/7U6A8234.CR2 ./HDR/train/7U6A8233_hdr +./LDR/7U6A8240.CR2 ./HDR/train/7U6A8239_hdr +./LDR/7U6A8249.CR2 ./HDR/train/7U6A8248_hdr +./LDR/7U6A8252.CR2 ./HDR/train/7U6A8251_hdr +./LDR/7U6A8258.CR2 ./HDR/train/7U6A8257_hdr +./LDR/7U6A8264.CR2 ./HDR/train/7U6A8263_hdr +./LDR/7U6A8270.CR2 ./HDR/train/7U6A8269_hdr +./LDR/7U6A8273.CR2 ./HDR/train/7U6A8272_hdr +./LDR/7U6A8276.CR2 ./HDR/train/7U6A8275_hdr +./LDR/7U6A8279.CR2 ./HDR/train/7U6A8278_hdr +./LDR/7U6A8285.CR2 ./HDR/train/7U6A8284_hdr +./LDR/7U6A8294.CR2 ./HDR/train/7U6A8293_hdr +./LDR/7U6A8300.CR2 ./HDR/train/7U6A8299_hdr +./LDR/7U6A8303.CR2 ./HDR/train/7U6A8302_hdr +./LDR/7U6A8315.CR2 ./HDR/train/7U6A8314_hdr +./LDR/7U6A8318.CR2 ./HDR/train/7U6A8317_hdr +./LDR/7U6A8321.CR2 ./HDR/train/7U6A8320_hdr +./LDR/7U6A8330.CR2 ./HDR/train/7U6A8329_hdr +./LDR/7U6A8336.CR2 ./HDR/train/7U6A8335_hdr +./LDR/7U6A8339.CR2 ./HDR/train/7U6A8338_hdr +./LDR/7U6A8348.CR2 ./HDR/train/7U6A8347_hdr +./LDR/7U6A8351.CR2 ./HDR/train/7U6A8350_hdr +./LDR/7U6A8354.CR2 ./HDR/train/7U6A8353_hdr +./LDR/7U6A8357.CR2 ./HDR/train/7U6A8356_hdr +./LDR/7U6A8360.CR2 ./HDR/train/7U6A8359_hdr +./LDR/7U6A8366.CR2 
./HDR/train/7U6A8365_hdr +./LDR/7U6A8372.CR2 ./HDR/train/7U6A8371_hdr +./LDR/7U6A8381.CR2 ./HDR/train/7U6A8380_hdr +./LDR/7U6A8384.CR2 ./HDR/train/7U6A8383_hdr +./LDR/7U6A8395.CR2 ./HDR/train/7U6A8394_hdr +./LDR/7U6A8398.CR2 ./HDR/train/7U6A8397_hdr +./LDR/7U6A8404.CR2 ./HDR/train/7U6A8403_hdr +./LDR/7U6A8407.CR2 ./HDR/train/7U6A8406_hdr +./LDR/7U6A8413.CR2 ./HDR/train/7U6A8412_hdr +./LDR/7U6A8416.CR2 ./HDR/train/7U6A8415_hdr +./LDR/7U6A8419.CR2 ./HDR/train/7U6A8418_hdr +./LDR/7U6A8422.CR2 ./HDR/train/7U6A8421_hdr +./LDR/7U6A8428.CR2 ./HDR/train/7U6A8427_hdr +./LDR/7U6A8434.CR2 ./HDR/train/7U6A8433_hdr +./LDR/7U6A8440.CR2 ./HDR/train/7U6A8439_hdr +./LDR/7U6A8443.CR2 ./HDR/train/7U6A8442_hdr +./LDR/7U6A8446.CR2 ./HDR/train/7U6A8445_hdr +./LDR/7U6A8452.CR2 ./HDR/train/7U6A8451_hdr +./LDR/7U6A8455.CR2 ./HDR/train/7U6A8454_hdr +./LDR/7U6A8458.CR2 ./HDR/train/7U6A8457_hdr +./LDR/7U6A8461.CR2 ./HDR/train/7U6A8460_hdr +./LDR/7U6A8473.CR2 ./HDR/train/7U6A8472_hdr +./LDR/7U6A8479.CR2 ./HDR/train/7U6A8478_hdr +./LDR/7U6A8482.CR2 ./HDR/train/7U6A8481_hdr +./LDR/7U6A8491.CR2 ./HDR/train/7U6A8490_hdr +./LDR/7U6A8494.CR2 ./HDR/train/7U6A8493_hdr +./LDR/7U6A8497.CR2 ./HDR/train/7U6A8496_hdr +./LDR/7U6A8500.CR2 ./HDR/train/7U6A8499_hdr +./LDR/7U6A8503.CR2 ./HDR/train/7U6A8502_hdr +./LDR/7U6A8506.CR2 ./HDR/train/7U6A8505_hdr +./LDR/7U6A8509.CR2 ./HDR/train/7U6A8508_hdr +./LDR/7U6A8512.CR2 ./HDR/train/7U6A8511_hdr +./LDR/7U6A8518.CR2 ./HDR/train/7U6A8517_hdr +./LDR/7U6A8521.CR2 ./HDR/train/7U6A8520_hdr +./LDR/7U6A8530.CR2 ./HDR/train/7U6A8529_hdr +./LDR/7U6A8539.CR2 ./HDR/train/7U6A8538_hdr +./LDR/7U6A8542.CR2 ./HDR/train/7U6A8541_hdr +./LDR/7U6A8545.CR2 ./HDR/train/7U6A8544_hdr +./LDR/7U6A8548.CR2 ./HDR/train/7U6A8547_hdr +./LDR/7U6A8551.CR2 ./HDR/train/7U6A8550_hdr +./LDR/7U6A8554.CR2 ./HDR/train/7U6A8553_hdr +./LDR/7U6A8560.CR2 ./HDR/train/7U6A8559_hdr +./LDR/7U6A8575.CR2 ./HDR/train/7U6A8574_hdr +./LDR/7U6A8578.CR2 ./HDR/train/7U6A8577_hdr +./LDR/7U6A8596.CR2 ./HDR/train/7U6A8595_hdr +./LDR/7U6A8599.CR2 ./HDR/train/7U6A8598_hdr +./LDR/7U6A8602.CR2 ./HDR/train/7U6A8601_hdr +./LDR/7U6A8605.CR2 ./HDR/train/7U6A8604_hdr +./LDR/7U6A8609.CR2 ./HDR/train/7U6A8608_hdr +./LDR/7U6A8633.CR2 ./HDR/train/7U6A8632_hdr +./LDR/7U6A8636.CR2 ./HDR/train/7U6A8635_hdr +./LDR/7U6A8642.CR2 ./HDR/train/7U6A8641_hdr +./LDR/7U6A8645.CR2 ./HDR/train/7U6A8644_hdr +./LDR/7U6A8648.CR2 ./HDR/train/7U6A8647_hdr +./LDR/7U6A8654.CR2 ./HDR/train/7U6A8653_hdr +./LDR/7U6A8660.CR2 ./HDR/train/7U6A8659_hdr +./LDR/7U6A8663.CR2 ./HDR/train/7U6A8662_hdr +./LDR/7U6A8678.CR2 ./HDR/train/7U6A8677_hdr +./LDR/7U6A8684.CR2 ./HDR/train/7U6A8683_hdr +./LDR/7U6A8690.CR2 ./HDR/train/7U6A8689_hdr +./LDR/7U6A8693.CR2 ./HDR/train/7U6A8692_hdr +./LDR/7U6A8699.CR2 ./HDR/train/7U6A8698_hdr +./LDR/7U6A8702.CR2 ./HDR/train/7U6A8701_hdr +./LDR/7U6A8705.CR2 ./HDR/train/7U6A8704_hdr +./LDR/7U6A8711.CR2 ./HDR/train/7U6A8710_hdr +./LDR/7U6A8714.CR2 ./HDR/train/7U6A8713_hdr +./LDR/7U6A8720.CR2 ./HDR/train/7U6A8719_hdr +./LDR/7U6A8723.CR2 ./HDR/train/7U6A8722_hdr +./LDR/7U6A8729.CR2 ./HDR/train/7U6A8728_hdr +./LDR/7U6A8735.CR2 ./HDR/train/7U6A8734_hdr +./LDR/7U6A8738.CR2 ./HDR/train/7U6A8737_hdr +./LDR/7U6A8744.CR2 ./HDR/train/7U6A8743_hdr +./LDR/7U6A8747.CR2 ./HDR/train/7U6A8746_hdr +./LDR/7U6A8750.CR2 ./HDR/train/7U6A8749_hdr +./LDR/7U6A8753.CR2 ./HDR/train/7U6A8752_hdr +./LDR/7U6A8762.CR2 ./HDR/train/7U6A8761_hdr +./LDR/7U6A8774.CR2 ./HDR/train/7U6A8773_hdr +./LDR/7U6A8786.CR2 ./HDR/train/7U6A8785_hdr +./LDR/7U6A8789.CR2 ./HDR/train/7U6A8788_hdr +./LDR/7U6A8884.CR2 
./HDR/train/7U6A8883_hdr +./LDR/7U6A8887.CR2 ./HDR/train/7U6A8886_hdr +./LDR/7U6A8896.CR2 ./HDR/train/7U6A8895_hdr +./LDR/7U6A8899.CR2 ./HDR/train/7U6A8898_hdr +./LDR/7U6A8905.CR2 ./HDR/train/7U6A8904_hdr +./LDR/7U6A8908.CR2 ./HDR/train/7U6A8907_hdr +./LDR/7U6A8917.CR2 ./HDR/train/7U6A8916_hdr +./LDR/7U6A8920.CR2 ./HDR/train/7U6A8919_hdr +./LDR/7U6A8926.CR2 ./HDR/train/7U6A8925_hdr +./LDR/7U6A8929.CR2 ./HDR/train/7U6A8928_hdr +./LDR/7U6A8935.CR2 ./HDR/train/7U6A8934_hdr +./LDR/7U6A8938.CR2 ./HDR/train/7U6A8937_hdr +./LDR/7U6A8944.CR2 ./HDR/train/7U6A8943_hdr +./LDR/7U6A8947.CR2 ./HDR/train/7U6A8946_hdr +./LDR/7U6A8953.CR2 ./HDR/train/7U6A8952_hdr +./LDR/7U6A8959.CR2 ./HDR/train/7U6A8958_hdr +./LDR/7U6A8965.CR2 ./HDR/train/7U6A8964_hdr +./LDR/7U6A8968.CR2 ./HDR/train/7U6A8967_hdr +./LDR/7U6A8976.CR2 ./HDR/train/7U6A8975_hdr +./LDR/7U6A8985.CR2 ./HDR/train/7U6A8984_hdr +./LDR/7U6A8991.CR2 ./HDR/train/7U6A8990_hdr +./LDR/7U6A8994.CR2 ./HDR/train/7U6A8993_hdr +./LDR/7U6A8997.CR2 ./HDR/train/7U6A8996_hdr +./LDR/7U6A9006.CR2 ./HDR/train/7U6A9005_hdr +./LDR/7U6A9009.CR2 ./HDR/train/7U6A9008_hdr +./LDR/7U6A9016.CR2 ./HDR/train/7U6A9015_hdr +./LDR/7U6A9022.CR2 ./HDR/train/7U6A9021_hdr +./LDR/7U6A9028.CR2 ./HDR/train/7U6A9027_hdr +./LDR/7U6A9031.CR2 ./HDR/train/7U6A9030_hdr +./LDR/7U6A9061.CR2 ./HDR/train/7U6A9060_hdr +./LDR/7U6A9064.CR2 ./HDR/train/7U6A9063_hdr +./LDR/7U6A9067.CR2 ./HDR/train/7U6A9066_hdr +./LDR/7U6A9079.CR2 ./HDR/train/7U6A9078_hdr +./LDR/7U6A9082.CR2 ./HDR/train/7U6A9081_hdr +./LDR/7U6A9088.CR2 ./HDR/train/7U6A9087_hdr +./LDR/7U6A9091.CR2 ./HDR/train/7U6A9090_hdr +./LDR/7U6A9094.CR2 ./HDR/train/7U6A9093_hdr +./LDR/7U6A9115.CR2 ./HDR/train/7U6A9114_hdr +./LDR/7U6A9121.CR2 ./HDR/train/7U6A9120_hdr +./LDR/7U6A9136.CR2 ./HDR/train/7U6A9135_hdr +./LDR/7U6A9139.CR2 ./HDR/train/7U6A9138_hdr +./LDR/7U6A9142.CR2 ./HDR/train/7U6A9141_hdr +./LDR/7U6A9145.CR2 ./HDR/train/7U6A9144_hdr +./LDR/7U6A9151.CR2 ./HDR/train/7U6A9150_hdr +./LDR/7U6A9160.CR2 ./HDR/train/7U6A9159_hdr +./LDR/7U6A9163.CR2 ./HDR/train/7U6A9162_hdr +./LDR/7U6A9172.CR2 ./HDR/train/7U6A9171_hdr +./LDR/7U6A9181.CR2 ./HDR/train/7U6A9180_hdr +./LDR/7U6A9184.CR2 ./HDR/train/7U6A9183_hdr +./LDR/7U6A9187.CR2 ./HDR/train/7U6A9186_hdr +./LDR/7U6A9190.CR2 ./HDR/train/7U6A9189_hdr +./LDR/7U6A9196.CR2 ./HDR/train/7U6A9195_hdr +./LDR/7U6A9208.CR2 ./HDR/train/7U6A9207_hdr +./LDR/7U6A9214.CR2 ./HDR/train/7U6A9213_hdr +./LDR/7U6A9220.CR2 ./HDR/train/7U6A9219_hdr +./LDR/7U6A9229.CR2 ./HDR/train/7U6A9228_hdr +./LDR/7U6A9238.CR2 ./HDR/train/7U6A9237_hdr +./LDR/7U6A9244.CR2 ./HDR/train/7U6A9243_hdr +./LDR/7U6A9250.CR2 ./HDR/train/7U6A9249_hdr +./LDR/7U6A9253.CR2 ./HDR/train/7U6A9252_hdr +./LDR/7U6A9259.CR2 ./HDR/train/7U6A9258_hdr +./LDR/7U6A9262.CR2 ./HDR/train/7U6A9261_hdr +./LDR/7U6A9265.CR2 ./HDR/train/7U6A9264_hdr +./LDR/7U6A9268.CR2 ./HDR/train/7U6A9267_hdr +./LDR/7U6A9274.CR2 ./HDR/train/7U6A9273_hdr +./LDR/7U6A9277.CR2 ./HDR/train/7U6A9276_hdr +./LDR/7U6A9280.CR2 ./HDR/train/7U6A9279_hdr +./LDR/7U6A9283.CR2 ./HDR/train/7U6A9282_hdr +./LDR/7U6A9304.CR2 ./HDR/train/7U6A9303_hdr +./LDR/7U6A9307.CR2 ./HDR/train/7U6A9306_hdr +./LDR/7U6A9310.CR2 ./HDR/train/7U6A9309_hdr +./LDR/7U6A9313.CR2 ./HDR/train/7U6A9312_hdr +./LDR/7U6A9319.CR2 ./HDR/train/7U6A9318_hdr +./LDR/7U6A9325.CR2 ./HDR/train/7U6A9324_hdr +./LDR/7U6A9331.CR2 ./HDR/train/7U6A9330_hdr +./LDR/7U6A9340.CR2 ./HDR/train/7U6A9339_hdr +./LDR/7U6A9352.CR2 ./HDR/train/7U6A9351_hdr +./LDR/7U6A9370.CR2 ./HDR/train/7U6A9369_hdr +./LDR/7U6A9376.CR2 ./HDR/train/7U6A9375_hdr +./LDR/7U6A9379.CR2 
./HDR/train/7U6A9378_hdr +./LDR/7U6A9382.CR2 ./HDR/train/7U6A9381_hdr +./LDR/7U6A9385.CR2 ./HDR/train/7U6A9384_hdr +./LDR/7U6A9400.CR2 ./HDR/train/7U6A9399_hdr +./LDR/7U6A9403.CR2 ./HDR/train/7U6A9402_hdr diff --git a/datasets/txtfiles/HDR/train1.txt b/datasets/txtfiles/HDR/train1.txt new file mode 100644 index 0000000..6e58fd4 --- /dev/null +++ b/datasets/txtfiles/HDR/train1.txt @@ -0,0 +1,226 @@ +./LDR/7U6A3449.CR2 ./HDR/train/7U6A3449_hdr +./LDR/7U6A3452.CR2 ./HDR/train/7U6A3452_hdr +./LDR/7U6A3455.CR2 ./HDR/train/7U6A3455_hdr +./LDR/7U6A3458.CR2 ./HDR/train/7U6A3458_hdr +./LDR/7U6A3461.CR2 ./HDR/train/7U6A3461_hdr +./LDR/7U6A3473.CR2 ./HDR/train/7U6A3473_hdr +./LDR/7U6A3476.CR2 ./HDR/train/7U6A3476_hdr +./LDR/7U6A3479.CR2 ./HDR/train/7U6A3479_hdr +./LDR/7U6A3494.CR2 ./HDR/train/7U6A3494_hdr +./LDR/7U6A3512.CR2 ./HDR/train/7U6A3512_hdr +./LDR/7U6A8059.CR2 ./HDR/train/7U6A8059_hdr +./LDR/7U6A8068.CR2 ./HDR/train/7U6A8068_hdr +./LDR/7U6A8074.CR2 ./HDR/train/7U6A8074_hdr +./LDR/7U6A8080.CR2 ./HDR/train/7U6A8080_hdr +./LDR/7U6A8083.CR2 ./HDR/train/7U6A8083_hdr +./LDR/7U6A8086.CR2 ./HDR/train/7U6A8086_hdr +./LDR/7U6A8095.CR2 ./HDR/train/7U6A8095_hdr +./LDR/7U6A8107.CR2 ./HDR/train/7U6A8107_hdr +./LDR/7U6A8110.CR2 ./HDR/train/7U6A8110_hdr +./LDR/7U6A8125.CR2 ./HDR/train/7U6A8125_hdr +./LDR/7U6A8128.CR2 ./HDR/train/7U6A8128_hdr +./LDR/7U6A8134.CR2 ./HDR/train/7U6A8134_hdr +./LDR/7U6A8140.CR2 ./HDR/train/7U6A8140_hdr +./LDR/7U6A8143.CR2 ./HDR/train/7U6A8143_hdr +./LDR/7U6A8158.CR2 ./HDR/train/7U6A8158_hdr +./LDR/7U6A8167.CR2 ./HDR/train/7U6A8167_hdr +./LDR/7U6A8173.CR2 ./HDR/train/7U6A8173_hdr +./LDR/7U6A8176.CR2 ./HDR/train/7U6A8176_hdr +./LDR/7U6A8179.CR2 ./HDR/train/7U6A8179_hdr +./LDR/7U6A8188.CR2 ./HDR/train/7U6A8188_hdr +./LDR/7U6A8191.CR2 ./HDR/train/7U6A8191_hdr +./LDR/7U6A8194.CR2 ./HDR/train/7U6A8194_hdr +./LDR/7U6A8200.CR2 ./HDR/train/7U6A8200_hdr +./LDR/7U6A8206.CR2 ./HDR/train/7U6A8206_hdr +./LDR/7U6A8218.CR2 ./HDR/train/7U6A8218_hdr +./LDR/7U6A8221.CR2 ./HDR/train/7U6A8221_hdr +./LDR/7U6A8224.CR2 ./HDR/train/7U6A8224_hdr +./LDR/7U6A8230.CR2 ./HDR/train/7U6A8230_hdr +./LDR/7U6A8233.CR2 ./HDR/train/7U6A8233_hdr +./LDR/7U6A8239.CR2 ./HDR/train/7U6A8239_hdr +./LDR/7U6A8248.CR2 ./HDR/train/7U6A8248_hdr +./LDR/7U6A8251.CR2 ./HDR/train/7U6A8251_hdr +./LDR/7U6A8257.CR2 ./HDR/train/7U6A8257_hdr +./LDR/7U6A8263.CR2 ./HDR/train/7U6A8263_hdr +./LDR/7U6A8269.CR2 ./HDR/train/7U6A8269_hdr +./LDR/7U6A8272.CR2 ./HDR/train/7U6A8272_hdr +./LDR/7U6A8275.CR2 ./HDR/train/7U6A8275_hdr +./LDR/7U6A8278.CR2 ./HDR/train/7U6A8278_hdr +./LDR/7U6A8284.CR2 ./HDR/train/7U6A8284_hdr +./LDR/7U6A8293.CR2 ./HDR/train/7U6A8293_hdr +./LDR/7U6A8299.CR2 ./HDR/train/7U6A8299_hdr +./LDR/7U6A8302.CR2 ./HDR/train/7U6A8302_hdr +./LDR/7U6A8314.CR2 ./HDR/train/7U6A8314_hdr +./LDR/7U6A8317.CR2 ./HDR/train/7U6A8317_hdr +./LDR/7U6A8320.CR2 ./HDR/train/7U6A8320_hdr +./LDR/7U6A8329.CR2 ./HDR/train/7U6A8329_hdr +./LDR/7U6A8335.CR2 ./HDR/train/7U6A8335_hdr +./LDR/7U6A8338.CR2 ./HDR/train/7U6A8338_hdr +./LDR/7U6A8347.CR2 ./HDR/train/7U6A8347_hdr +./LDR/7U6A8350.CR2 ./HDR/train/7U6A8350_hdr +./LDR/7U6A8353.CR2 ./HDR/train/7U6A8353_hdr +./LDR/7U6A8356.CR2 ./HDR/train/7U6A8356_hdr +./LDR/7U6A8359.CR2 ./HDR/train/7U6A8359_hdr +./LDR/7U6A8365.CR2 ./HDR/train/7U6A8365_hdr +./LDR/7U6A8371.CR2 ./HDR/train/7U6A8371_hdr +./LDR/7U6A8380.CR2 ./HDR/train/7U6A8380_hdr +./LDR/7U6A8383.CR2 ./HDR/train/7U6A8383_hdr +./LDR/7U6A8394.CR2 ./HDR/train/7U6A8394_hdr +./LDR/7U6A8397.CR2 ./HDR/train/7U6A8397_hdr +./LDR/7U6A8403.CR2 ./HDR/train/7U6A8403_hdr 
+./LDR/7U6A8406.CR2 ./HDR/train/7U6A8406_hdr +./LDR/7U6A8412.CR2 ./HDR/train/7U6A8412_hdr +./LDR/7U6A8415.CR2 ./HDR/train/7U6A8415_hdr +./LDR/7U6A8418.CR2 ./HDR/train/7U6A8418_hdr +./LDR/7U6A8421.CR2 ./HDR/train/7U6A8421_hdr +./LDR/7U6A8427.CR2 ./HDR/train/7U6A8427_hdr +./LDR/7U6A8433.CR2 ./HDR/train/7U6A8433_hdr +./LDR/7U6A8439.CR2 ./HDR/train/7U6A8439_hdr +./LDR/7U6A8442.CR2 ./HDR/train/7U6A8442_hdr +./LDR/7U6A8445.CR2 ./HDR/train/7U6A8445_hdr +./LDR/7U6A8451.CR2 ./HDR/train/7U6A8451_hdr +./LDR/7U6A8454.CR2 ./HDR/train/7U6A8454_hdr +./LDR/7U6A8457.CR2 ./HDR/train/7U6A8457_hdr +./LDR/7U6A8460.CR2 ./HDR/train/7U6A8460_hdr +./LDR/7U6A8472.CR2 ./HDR/train/7U6A8472_hdr +./LDR/7U6A8478.CR2 ./HDR/train/7U6A8478_hdr +./LDR/7U6A8481.CR2 ./HDR/train/7U6A8481_hdr +./LDR/7U6A8490.CR2 ./HDR/train/7U6A8490_hdr +./LDR/7U6A8493.CR2 ./HDR/train/7U6A8493_hdr +./LDR/7U6A8496.CR2 ./HDR/train/7U6A8496_hdr +./LDR/7U6A8499.CR2 ./HDR/train/7U6A8499_hdr +./LDR/7U6A8502.CR2 ./HDR/train/7U6A8502_hdr +./LDR/7U6A8505.CR2 ./HDR/train/7U6A8505_hdr +./LDR/7U6A8508.CR2 ./HDR/train/7U6A8508_hdr +./LDR/7U6A8511.CR2 ./HDR/train/7U6A8511_hdr +./LDR/7U6A8517.CR2 ./HDR/train/7U6A8517_hdr +./LDR/7U6A8520.CR2 ./HDR/train/7U6A8520_hdr +./LDR/7U6A8529.CR2 ./HDR/train/7U6A8529_hdr +./LDR/7U6A8538.CR2 ./HDR/train/7U6A8538_hdr +./LDR/7U6A8541.CR2 ./HDR/train/7U6A8541_hdr +./LDR/7U6A8544.CR2 ./HDR/train/7U6A8544_hdr +./LDR/7U6A8547.CR2 ./HDR/train/7U6A8547_hdr +./LDR/7U6A8550.CR2 ./HDR/train/7U6A8550_hdr +./LDR/7U6A8553.CR2 ./HDR/train/7U6A8553_hdr +./LDR/7U6A8559.CR2 ./HDR/train/7U6A8559_hdr +./LDR/7U6A8574.CR2 ./HDR/train/7U6A8574_hdr +./LDR/7U6A8577.CR2 ./HDR/train/7U6A8577_hdr +./LDR/7U6A8595.CR2 ./HDR/train/7U6A8595_hdr +./LDR/7U6A8598.CR2 ./HDR/train/7U6A8598_hdr +./LDR/7U6A8601.CR2 ./HDR/train/7U6A8601_hdr +./LDR/7U6A8604.CR2 ./HDR/train/7U6A8604_hdr +./LDR/7U6A8608.CR2 ./HDR/train/7U6A8608_hdr +./LDR/7U6A8632.CR2 ./HDR/train/7U6A8632_hdr +./LDR/7U6A8635.CR2 ./HDR/train/7U6A8635_hdr +./LDR/7U6A8641.CR2 ./HDR/train/7U6A8641_hdr +./LDR/7U6A8644.CR2 ./HDR/train/7U6A8644_hdr +./LDR/7U6A8647.CR2 ./HDR/train/7U6A8647_hdr +./LDR/7U6A8653.CR2 ./HDR/train/7U6A8653_hdr +./LDR/7U6A8659.CR2 ./HDR/train/7U6A8659_hdr +./LDR/7U6A8662.CR2 ./HDR/train/7U6A8662_hdr +./LDR/7U6A8677.CR2 ./HDR/train/7U6A8677_hdr +./LDR/7U6A8683.CR2 ./HDR/train/7U6A8683_hdr +./LDR/7U6A8689.CR2 ./HDR/train/7U6A8689_hdr +./LDR/7U6A8692.CR2 ./HDR/train/7U6A8692_hdr +./LDR/7U6A8698.CR2 ./HDR/train/7U6A8698_hdr +./LDR/7U6A8701.CR2 ./HDR/train/7U6A8701_hdr +./LDR/7U6A8704.CR2 ./HDR/train/7U6A8704_hdr +./LDR/7U6A8710.CR2 ./HDR/train/7U6A8710_hdr +./LDR/7U6A8713.CR2 ./HDR/train/7U6A8713_hdr +./LDR/7U6A8719.CR2 ./HDR/train/7U6A8719_hdr +./LDR/7U6A8722.CR2 ./HDR/train/7U6A8722_hdr +./LDR/7U6A8728.CR2 ./HDR/train/7U6A8728_hdr +./LDR/7U6A8734.CR2 ./HDR/train/7U6A8734_hdr +./LDR/7U6A8737.CR2 ./HDR/train/7U6A8737_hdr +./LDR/7U6A8743.CR2 ./HDR/train/7U6A8743_hdr +./LDR/7U6A8746.CR2 ./HDR/train/7U6A8746_hdr +./LDR/7U6A8749.CR2 ./HDR/train/7U6A8749_hdr +./LDR/7U6A8752.CR2 ./HDR/train/7U6A8752_hdr +./LDR/7U6A8761.CR2 ./HDR/train/7U6A8761_hdr +./LDR/7U6A8773.CR2 ./HDR/train/7U6A8773_hdr +./LDR/7U6A8785.CR2 ./HDR/train/7U6A8785_hdr +./LDR/7U6A8788.CR2 ./HDR/train/7U6A8788_hdr +./LDR/7U6A8883.CR2 ./HDR/train/7U6A8883_hdr +./LDR/7U6A8886.CR2 ./HDR/train/7U6A8886_hdr +./LDR/7U6A8895.CR2 ./HDR/train/7U6A8895_hdr +./LDR/7U6A8898.CR2 ./HDR/train/7U6A8898_hdr +./LDR/7U6A8904.CR2 ./HDR/train/7U6A8904_hdr +./LDR/7U6A8907.CR2 ./HDR/train/7U6A8907_hdr +./LDR/7U6A8916.CR2 ./HDR/train/7U6A8916_hdr 
+./LDR/7U6A8919.CR2 ./HDR/train/7U6A8919_hdr +./LDR/7U6A8925.CR2 ./HDR/train/7U6A8925_hdr +./LDR/7U6A8928.CR2 ./HDR/train/7U6A8928_hdr +./LDR/7U6A8934.CR2 ./HDR/train/7U6A8934_hdr +./LDR/7U6A8937.CR2 ./HDR/train/7U6A8937_hdr +./LDR/7U6A8943.CR2 ./HDR/train/7U6A8943_hdr +./LDR/7U6A8946.CR2 ./HDR/train/7U6A8946_hdr +./LDR/7U6A8952.CR2 ./HDR/train/7U6A8952_hdr +./LDR/7U6A8958.CR2 ./HDR/train/7U6A8958_hdr +./LDR/7U6A8964.CR2 ./HDR/train/7U6A8964_hdr +./LDR/7U6A8967.CR2 ./HDR/train/7U6A8967_hdr +./LDR/7U6A8975.CR2 ./HDR/train/7U6A8975_hdr +./LDR/7U6A8984.CR2 ./HDR/train/7U6A8984_hdr +./LDR/7U6A8990.CR2 ./HDR/train/7U6A8990_hdr +./LDR/7U6A8993.CR2 ./HDR/train/7U6A8993_hdr +./LDR/7U6A8996.CR2 ./HDR/train/7U6A8996_hdr +./LDR/7U6A9005.CR2 ./HDR/train/7U6A9005_hdr +./LDR/7U6A9008.CR2 ./HDR/train/7U6A9008_hdr +./LDR/7U6A9015.CR2 ./HDR/train/7U6A9015_hdr +./LDR/7U6A9021.CR2 ./HDR/train/7U6A9021_hdr +./LDR/7U6A9027.CR2 ./HDR/train/7U6A9027_hdr +./LDR/7U6A9030.CR2 ./HDR/train/7U6A9030_hdr +./LDR/7U6A9060.CR2 ./HDR/train/7U6A9060_hdr +./LDR/7U6A9063.CR2 ./HDR/train/7U6A9063_hdr +./LDR/7U6A9066.CR2 ./HDR/train/7U6A9066_hdr +./LDR/7U6A9078.CR2 ./HDR/train/7U6A9078_hdr +./LDR/7U6A9081.CR2 ./HDR/train/7U6A9081_hdr +./LDR/7U6A9087.CR2 ./HDR/train/7U6A9087_hdr +./LDR/7U6A9090.CR2 ./HDR/train/7U6A9090_hdr +./LDR/7U6A9093.CR2 ./HDR/train/7U6A9093_hdr +./LDR/7U6A9114.CR2 ./HDR/train/7U6A9114_hdr +./LDR/7U6A9120.CR2 ./HDR/train/7U6A9120_hdr +./LDR/7U6A9135.CR2 ./HDR/train/7U6A9135_hdr +./LDR/7U6A9138.CR2 ./HDR/train/7U6A9138_hdr +./LDR/7U6A9141.CR2 ./HDR/train/7U6A9141_hdr +./LDR/7U6A9144.CR2 ./HDR/train/7U6A9144_hdr +./LDR/7U6A9150.CR2 ./HDR/train/7U6A9150_hdr +./LDR/7U6A9159.CR2 ./HDR/train/7U6A9159_hdr +./LDR/7U6A9162.CR2 ./HDR/train/7U6A9162_hdr +./LDR/7U6A9171.CR2 ./HDR/train/7U6A9171_hdr +./LDR/7U6A9180.CR2 ./HDR/train/7U6A9180_hdr +./LDR/7U6A9183.CR2 ./HDR/train/7U6A9183_hdr +./LDR/7U6A9186.CR2 ./HDR/train/7U6A9186_hdr +./LDR/7U6A9189.CR2 ./HDR/train/7U6A9189_hdr +./LDR/7U6A9195.CR2 ./HDR/train/7U6A9195_hdr +./LDR/7U6A9207.CR2 ./HDR/train/7U6A9207_hdr +./LDR/7U6A9213.CR2 ./HDR/train/7U6A9213_hdr +./LDR/7U6A9219.CR2 ./HDR/train/7U6A9219_hdr +./LDR/7U6A9228.CR2 ./HDR/train/7U6A9228_hdr +./LDR/7U6A9237.CR2 ./HDR/train/7U6A9237_hdr +./LDR/7U6A9243.CR2 ./HDR/train/7U6A9243_hdr +./LDR/7U6A9249.CR2 ./HDR/train/7U6A9249_hdr +./LDR/7U6A9252.CR2 ./HDR/train/7U6A9252_hdr +./LDR/7U6A9258.CR2 ./HDR/train/7U6A9258_hdr +./LDR/7U6A9261.CR2 ./HDR/train/7U6A9261_hdr +./LDR/7U6A9264.CR2 ./HDR/train/7U6A9264_hdr +./LDR/7U6A9267.CR2 ./HDR/train/7U6A9267_hdr +./LDR/7U6A9273.CR2 ./HDR/train/7U6A9273_hdr +./LDR/7U6A9276.CR2 ./HDR/train/7U6A9276_hdr +./LDR/7U6A9279.CR2 ./HDR/train/7U6A9279_hdr +./LDR/7U6A9282.CR2 ./HDR/train/7U6A9282_hdr +./LDR/7U6A9303.CR2 ./HDR/train/7U6A9303_hdr +./LDR/7U6A9306.CR2 ./HDR/train/7U6A9306_hdr +./LDR/7U6A9309.CR2 ./HDR/train/7U6A9309_hdr +./LDR/7U6A9312.CR2 ./HDR/train/7U6A9312_hdr +./LDR/7U6A9318.CR2 ./HDR/train/7U6A9318_hdr +./LDR/7U6A9324.CR2 ./HDR/train/7U6A9324_hdr +./LDR/7U6A9330.CR2 ./HDR/train/7U6A9330_hdr +./LDR/7U6A9339.CR2 ./HDR/train/7U6A9339_hdr +./LDR/7U6A9351.CR2 ./HDR/train/7U6A9351_hdr +./LDR/7U6A9369.CR2 ./HDR/train/7U6A9369_hdr +./LDR/7U6A9375.CR2 ./HDR/train/7U6A9375_hdr +./LDR/7U6A9378.CR2 ./HDR/train/7U6A9378_hdr +./LDR/7U6A9381.CR2 ./HDR/train/7U6A9381_hdr +./LDR/7U6A9384.CR2 ./HDR/train/7U6A9384_hdr +./LDR/7U6A9399.CR2 ./HDR/train/7U6A9399_hdr +./LDR/7U6A9402.CR2 ./HDR/train/7U6A9402_hdr From 58f8b91d730ca3bf9383b718ab4d901b3af1f9f7 Mon Sep 17 00:00:00 2001 From: mya012 
<96360296+mya012@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:28:45 +0800 Subject: [PATCH 2/6] Add files via upload --- options/base/network_g/cunet.yaml | 5 +++++ options/base/network_g/unet41.yaml | 5 +++++ 2 files changed, 10 insertions(+) create mode 100644 options/base/network_g/cunet.yaml create mode 100644 options/base/network_g/unet41.yaml diff --git a/options/base/network_g/cunet.yaml b/options/base/network_g/cunet.yaml new file mode 100644 index 0000000..d41faf5 --- /dev/null +++ b/options/base/network_g/cunet.yaml @@ -0,0 +1,5 @@ +network_g: + type: CUNetArch + inchannels: 4 + outchannels: 4 + channels: 32 \ No newline at end of file diff --git a/options/base/network_g/unet41.yaml b/options/base/network_g/unet41.yaml new file mode 100644 index 0000000..5d42d6c --- /dev/null +++ b/options/base/network_g/unet41.yaml @@ -0,0 +1,5 @@ +network_g: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 \ No newline at end of file From 17d50e0d184801b73d42aebceddbc9ba3e67e842 Mon Sep 17 00:00:00 2001 From: mya012 <96360296+mya012@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:29:27 +0800 Subject: [PATCH 3/6] Add files via upload --- options/UltraLED/train_step1.yaml | 129 +++++++++++++++++++++++++++++ options/UltraLED/train_step2.yaml | 131 ++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 options/UltraLED/train_step1.yaml create mode 100644 options/UltraLED/train_step2.yaml diff --git a/options/UltraLED/train_step1.yaml b/options/UltraLED/train_step1.yaml new file mode 100644 index 0000000..46350c2 --- /dev/null +++ b/options/UltraLED/train_step1.yaml @@ -0,0 +1,129 @@ +# general settings +name: Train_step1 +model_type: RatioMapEstimatorModel +scale: 1 +num_gpu: 1 +manual_seed: 2022 +metric_in_srgb: false +CRF_path: datasets/EMoR + +# TODO: update the path to the dataroot +# dataset and data loader settings +datasets: + train: + name: EstimatorTrain + type: EstimatorHDRRAWDataset + dataroot: path/to/RAWHDRnpz + postfix: npz + which_meta: gt + data_pair_list: datasets/txtfiles/HDR/train.txt + zero_clip: false + crop_size: ~ + load_in_mem: true + + ratio_range: [1, 31] + noise_type: ptrqc + + + use_patches: false + + use_hflip: true + use_rot: true + crop_size: 512 + load_in_mem: false + + # data loadecr + num_worker_per_gpu: 8 + batch_size_per_gpu: 1 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + +noise_g: + noise_type: ptrqc + camera_params: + SonyA7M4: + Kmin: 0.5176822379876143 + Kmax: 37.48342111254088 + Row: + slope: 0.8498868627654588 + bias: -2.508114778989769 + sigma: 0.2539644785570933 + Gaussian: + slope: 0.9469078551868472 + bias: 0.16045036103138344 + sigma: 0.20521559771812753 + TurkeyLambda: + slope: 0.8741245752046553 + bias: -0.7834371464250983 + sigma: 0.24359869793646144 + lambda: [-0.09686645120382309, -0.20266661047935486, -0.21817433834075928, -0.2772795557975769, -0.26944655179977417, -0.271323561668396, -0.2704177498817444, -0.2713994085788727, -0.27039051055908203, -0.2710208296775818, -0.2720312476158142, -0.2751300632953644, -0.2763080298900604, -0.27588126063346863, -0.2749999165534973, -0.27548256516456604, -0.27511167526245117, -0.276935875415802, -0.28096914291381836, -0.2802513539791107, -0.28123438358306885, -0.28208523988723755, -0.28266826272010803, -0.2647216022014618, -0.26392486691474915, -0.26441627740859985, -0.2643108367919922, -0.2627350389957428, -0.26077425479888916, -0.2635541260242462] + ColorBias: [[0.23521856874227523, 0.22085916072130204, 0.2863488158583641, 0.22595973789691925], 
[0.3256998673081398, 0.3030224359035492, 0.34010037809610366, 0.3117868521809578], [0.2417914468050003, 0.23460582494735718, 0.2810402739048004, 0.2346685606241226], [0.2893064874410629, 0.2924955692887306, 0.3213977232575417, 0.32337920874357223], [0.25456790179014205, 0.2652137302979827, 0.304450766146183, 0.29432806588709354], [0.1782449632883072, 0.30211541056632996, 0.2497684806585312, 0.24955467879772186], [0.43470048904418945, 0.4313555061817169, 0.34231114387512207, 0.35701221227645874], [0.34076905250549316, -0.01008229423314333, 0.42228585481643677, 0.144450843334198], [0.09368899464607239, 0.10415662825107574, 0.3285943269729614, 0.05827445909380913], [0.13945339620113373, 0.23084817826747894, 0.13442949950695038, 0.28657567501068115], [0.252593609985197, 0.2493764951825142, 0.612515584230423, 0.6146858245134353], [0.16977471113204956, -0.08632060885429382, 0.5320069193840027, 0.3316039741039276], [0.3789997398853302, 0.4710197448730469, 0.5400363206863403, 0.7923370003700256], [0.3004434108734131, 0.2512214779853821, 0.7026330232620239, 0.65498948097229], [0.3602710623666644, 0.3627533086016774, 0.8907627087831497, 1.0329724252223969], [-0.18567052483558655, 0.4224573075771332, 0.8947332501411438, 1.0053366422653198], [0.6430071445275098, 0.3353357759770006, 2.0499183797836302, 1.9346534669399262], [0.960235595703125, -0.4397820830345154, 1.5321524143218994, 1.1829473972320557]] + engine: torch + + +# network structures +network_g: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 + +network_d: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 + +# path to the pretrained models +network_d_path: ~ + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +train: + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.999] + + scheduler: + type: HandieLR + milestones: [20000, 24000] + lrs: [9.0e-5, 8.0e-5] + + total_iter: 25000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: RAWL1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 9999999999999 + save_img: false + calculate_metric_in_batch: true + illumination_correct: true + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 2 + test_y_channel: false + ssim: # metric name, can be arbitrary + type: calculate_ssim + crop_border: 2 + test_y_channel: false + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1288 + use_tb_logger: true \ No newline at end of file diff --git a/options/UltraLED/train_step2.yaml b/options/UltraLED/train_step2.yaml new file mode 100644 index 0000000..f1cc1d1 --- /dev/null +++ b/options/UltraLED/train_step2.yaml @@ -0,0 +1,131 @@ +# general settings +name: Train_step2 +model_type: RAWDenoiserModel +scale: 1 +num_gpu: 1 +manual_seed: 2022 +metric_in_srgb: false +CRF_path: datasets/EMoR + +# TODO: update the path to the dataroot +# dataset and data loader settings +datasets: + train: + name: DenoiserTrain + type: DenoiserHDRRAWDataset + dataroot: path/to/RAWHDRnpz + postfix: npz + which_meta: gt + data_pair_list: datasets/txtfiles/HDR/train.txt + zero_clip: false + crop_size: ~ + load_in_mem: true + + ratio_range: [1, 31] + noise_type: ptrqc + + + use_patches: false + + use_hflip: true + use_rot: true + crop_size: 512 + load_in_mem: false + + # data loadecr + num_worker_per_gpu: 10 + batch_size_per_gpu: 8 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + +noise_g: + noise_type: ptrqc + camera_params: + 
SonyA7M4: + Kmin: 0.5176822379876143 + Kmax: 37.48342111254088 + Row: + slope: 0.8498868627654588 + bias: -2.508114778989769 + sigma: 0.2539644785570933 + Gaussian: + slope: 0.9469078551868472 + bias: 0.16045036103138344 + sigma: 0.20521559771812753 + TurkeyLambda: + slope: 0.8741245752046553 + bias: -0.7834371464250983 + sigma: 0.24359869793646144 + lambda: [-0.09686645120382309, -0.20266661047935486, -0.21817433834075928, -0.2772795557975769, -0.26944655179977417, -0.271323561668396, -0.2704177498817444, -0.2713994085788727, -0.27039051055908203, -0.2710208296775818, -0.2720312476158142, -0.2751300632953644, -0.2763080298900604, -0.27588126063346863, -0.2749999165534973, -0.27548256516456604, -0.27511167526245117, -0.276935875415802, -0.28096914291381836, -0.2802513539791107, -0.28123438358306885, -0.28208523988723755, -0.28266826272010803, -0.2647216022014618, -0.26392486691474915, -0.26441627740859985, -0.2643108367919922, -0.2627350389957428, -0.26077425479888916, -0.2635541260242462] + ColorBias: [[0.23521856874227523, 0.22085916072130204, 0.2863488158583641, 0.22595973789691925], [0.3256998673081398, 0.3030224359035492, 0.34010037809610366, 0.3117868521809578], [0.2417914468050003, 0.23460582494735718, 0.2810402739048004, 0.2346685606241226], [0.2893064874410629, 0.2924955692887306, 0.3213977232575417, 0.32337920874357223], [0.25456790179014205, 0.2652137302979827, 0.304450766146183, 0.29432806588709354], [0.1782449632883072, 0.30211541056632996, 0.2497684806585312, 0.24955467879772186], [0.43470048904418945, 0.4313555061817169, 0.34231114387512207, 0.35701221227645874], [0.34076905250549316, -0.01008229423314333, 0.42228585481643677, 0.144450843334198], [0.09368899464607239, 0.10415662825107574, 0.3285943269729614, 0.05827445909380913], [0.13945339620113373, 0.23084817826747894, 0.13442949950695038, 0.28657567501068115], [0.252593609985197, 0.2493764951825142, 0.612515584230423, 0.6146858245134353], [0.16977471113204956, -0.08632060885429382, 0.5320069193840027, 0.3316039741039276], [0.3789997398853302, 0.4710197448730469, 0.5400363206863403, 0.7923370003700256], [0.3004434108734131, 0.2512214779853821, 0.7026330232620239, 0.65498948097229], [0.3602710623666644, 0.3627533086016774, 0.8907627087831497, 1.0329724252223969], [-0.18567052483558655, 0.4224573075771332, 0.8947332501411438, 1.0053366422653198], [0.6430071445275098, 0.3353357759770006, 2.0499183797836302, 1.9346534669399262], [0.960235595703125, -0.4397820830345154, 1.5321524143218994, 1.1829473972320557]] + engine: torch + +# network structures +network_g: + type: CUNetArch + inchannels: 4 + outchannels: 4 + channels: 32 + +network_d: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 + +# TODO: update the path to the pretrained ratio map estimator models +# path to the pretrained ratio map estimator models +network_d_path: path/to/the_pretrained_ratio_map_estimator_model + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.999] + + scheduler: + type: CosineAnnealingRestartLR + eta_min: !!float 1e-5 + periods: [96600, 193200, 289800] + restart_weights: [1, 0.5, 0.25] + + total_iter: 270480 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 9999999999999 + save_img: false + calculate_metric_in_batch: true + illumination_correct: true + + metrics: + psnr: # metric 
name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 2
+      test_y_channel: false
+    ssim: # metric name, can be arbitrary
+      type: calculate_ssim
+      crop_border: 2
+      test_y_channel: false
+
+# logging settings
+logger:
+  print_freq: 200
+  save_checkpoint_freq: !!float 1288
+  use_tb_logger: true

From 47e39d817fa84b082e1f94eac28a7d471e3ef433 Mon Sep 17 00:00:00 2001
From: mya012 <96360296+mya012@users.noreply.github.com>
Date: Mon, 20 Oct 2025 11:30:18 +0800
Subject: [PATCH 4/6] Add files via upload

---
 scripts/image_process_ultraled.py | 137 ++++++++++++++++++++++++++++++
 scripts/test_metrics_ultraled.py  | 129 ++++++++++++++++++++++++++++
 2 files changed, 266 insertions(+)
 create mode 100644 scripts/image_process_ultraled.py
 create mode 100644 scripts/test_metrics_ultraled.py

diff --git a/scripts/image_process_ultraled.py b/scripts/image_process_ultraled.py
new file mode 100644
index 0000000..f31c133
--- /dev/null
+++ b/scripts/image_process_ultraled.py
@@ -0,0 +1,137 @@
+import argparse
+import glob
+import os
+import time
+from copy import deepcopy
+import math
+
+import cv2
+import numpy as np
+import rawpy
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+import sys
+
+from ultraled.archs import build_network
+from ultraled.utils.options import yaml_load
+from ultraled.data.raw_utils import *
+
+
+def load_network(net, load_path, strict=True, param_key='params'):
+    """Load network weights from checkpoint."""
+    load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
+
+    if param_key is not None:
+        if param_key not in load_net and 'params' in load_net:
+            param_key = 'params'
+            print('Loading: params_ema does not exist, use params.')
+        load_net = load_net[param_key]
+
+    print(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
+
+    for k, v in deepcopy(load_net).items():
+        if k.startswith('module.'):
+            load_net[k[7:]] = v
+            load_net.pop(k)
+
+    net.load_state_dict(load_net, strict=strict)
+
+
+def get_available_device():
+    """Get available computing device."""
+    if torch.cuda.is_available():
+        return torch.device('cuda')
+    if torch.backends.mps.is_available():
+        return torch.device('mps')
+    return torch.device('cpu')
+
+
+def setup_network(network_options, pretrained_path):
+    """Setup and load network."""
+    print('Building network...')
+    print(network_options)
+    network = build_network(yaml_load(network_options)['network_g'])
+
+    print('Loading checkpoint...')
+    load_network(network, pretrained_path)
+
+    device = get_available_device()
+    return network.to(device)
+
+
+@torch.no_grad()
+def image_process():
+    """Main image processing pipeline."""
+    parser = argparse.ArgumentParser(description='Image processing pipeline')
+    parser.add_argument('-p', '--pretrained_network', type=str, required=True,
+                        help='the pretrained ratio map estimator network path.')
+    parser.add_argument('-pd', '--pretrained_denosing_network', type=str, required=True,
+                        help='the pretrained network path for denoising.')
+    parser.add_argument('--data_path', type=str, required=True,
+                        help='the folder containing raw images to be processed.')
+    parser.add_argument('--save_path', type=str, default='inference/image_process',
+                        help='output folder for processed images.')
+    parser.add_argument('-opt', '--network_options',
+                        default='options/base/network_g/unet41.yaml',
+                        help='ratio map estimator network architecture options.')
+    parser.add_argument('-optd', '--denoising_network_options',
+                        default='options/base/network_g/cunet.yaml',
help='denoising network architecture options.') + parser.add_argument('--ratio', '--dgain', type=float, default=1.0, + help='maximum exposure gain ratio.') + parser.add_argument('--target_exposure', type=float, + help='target exposure (overrides ratio).') + parser.add_argument('--bps', '--output_bps', type=int, default=8, + help='output bit depth.') + + args = parser.parse_args() + + device = get_available_device() + + network_g = setup_network(args.network_options, args.pretrained_network) + network_gd = setup_network(args.denoising_network_options, args.pretrained_denosing_network) + + raw_paths = sorted(glob.glob(f'{args.data_path}/*')) + ratio = args.ratio + os.makedirs(args.save_path, exist_ok=True) + + for raw_path in tqdm(raw_paths, desc="Processing images"): + start_time = time.time() + + if args.target_exposure is not None: + iso, exp_time = metainfo(raw_path) + ratio = args.target_exposure / (iso * exp_time) + + raw, raw_pattern, im, bl, wl = read_img(raw_path) + im0 = (im - bl) / (wl - bl) + im_normalized = im0 * ratio + + im_normalized = im_normalized.to(device) + result = network_g(im_normalized) + result = filter_bilateral(result, 15, torch.tensor(15.0).cuda(), torch.tensor(1.0).cuda()) + result = result.cpu() + + ratiomap_output = result + realmap = torch.tensor(ratio / ratiomap_output).to(device) + result = im0 * ratio / ratiomap_output + + result = result.clip(0.0, 1.0).to(device) + + ratio1 = realmap.mean().item() + result1 = network_gd(result, realmap, if_train=False) + result1 = result1.cpu().clip(0, 1) + + rgb = postprocess(raw, raw_pattern, result1, bl, wl, args.bps) + rgb_tensor = torch.FloatTensor(rgb / 1.0) + rgb_final = rgb_tensor.cpu().numpy().astype(np.uint8) + + base_save_path = raw_path.replace(args.data_path, args.save_path) + cv2.imwrite(f'{base_save_path}.png', rgb_final) + + raw.close() + + +if __name__ == '__main__': + image_process() \ No newline at end of file diff --git a/scripts/test_metrics_ultraled.py b/scripts/test_metrics_ultraled.py new file mode 100644 index 0000000..89bd4d1 --- /dev/null +++ b/scripts/test_metrics_ultraled.py @@ -0,0 +1,129 @@ +import os +import glob +import numpy as np +import cv2 +import torch +from torch import nn +import pyiqa +from tqdm import tqdm +import re +import sys + +from ultraled.data.raw_utils import * + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +def match_image_pairs(folder_path_A, folder_path_B): + pairs = [] + + for fname in os.listdir(folder_path_A): + if not fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): + continue + + match = re.search(r'_(\d+)\.', fname) + if (match and int(match.group(1)) == 10) or (match and int(match.group(1)) == 7): + continue + + try: + prefix = fname.split('_')[0] + scene_num = int(prefix) + except: + continue + + target_fname = f"scene{scene_num}hdr.png" + target_path = os.path.join(folder_path_B, target_fname) + + if os.path.exists(target_path): + src_path = os.path.join(folder_path_A, fname) + pairs.append((src_path, target_path)) + + return pairs + +def calculate_metrics_with_correction(folder_A, folder_B): + psnr_metric = pyiqa.create_metric('psnr', device=device) + ssim_metric = pyiqa.create_metric('ssim', device=device) + lpips_metric = pyiqa.create_metric('lpips', device=device) + + image_pairs = match_image_pairs(folder_A, folder_B) + + if not image_pairs: + print("No valid image pairs found") + return + + total_psnr = 0.0 + total_ssim = 0.0 + total_lpips = 0.0 + valid_pairs = 0 + + pbar = tqdm(image_pairs, 
desc="Processing") + for a_path, b_path in pbar: + try: + img_a = load_image(a_path) + img_b = load_image(b_path) + + img_a = img_a / 255.0 + img_b = img_b / 255.0 + + if img_a.shape[:2] != img_b.shape[:2]: + img_b = resize_image(img_b, target_shape=img_a.shape[:2], is_mask=False) + + tensor_a = image_to_tensor(img_a).clip(0, 1).to(device) + tensor_b = image_to_tensor(img_b).clip(0, 1).to(device) + + corrected_a = illuminance_correct(tensor_a, tensor_b) + corrected_a = corrected_a.clamp(0, 1) + + with torch.no_grad(): + psnr_score = psnr_metric(corrected_a, tensor_b) + ssim_score = ssim_metric(corrected_a, tensor_b) + lpips_score = lpips_metric(corrected_a, tensor_b) + + total_psnr += psnr_score.item() + total_ssim += ssim_score.item() + total_lpips += lpips_score.item() + valid_pairs += 1 + + current_metrics = { + 'PSNR': f"{psnr_score.item():.4f}", + 'SSIM': f"{ssim_score.item():.4f}", + 'LPIPS': f"{lpips_score.item():.4f}" + } + + pbar.set_postfix(current_metrics) + + except Exception as e: + pbar.write(f"Failed {os.path.basename(a_path)}: {str(e)}") + continue + + if valid_pairs > 0: + avg_psnr = total_psnr / valid_pairs + avg_ssim = total_ssim / valid_pairs + avg_lpips = total_lpips / valid_pairs + + print(f" PSNR: {avg_psnr:.4f}", f" SSIM: {avg_ssim:.4f}", f" LPIPS: {avg_lpips:.4f}") + + return { + 'psnr': avg_psnr, + 'ssim': avg_ssim, + 'lpips': avg_lpips, + 'valid_pairs': valid_pairs + } + else: + print("No valid image pairs for evaluation") + return None + +if __name__ == "__main__": + # Place test data in folder A and ground truth in folder B. When evaluating metrics, do not alter the naming format of source files. + results = calculate_metrics_with_correction( + folder_A="../../../defaultShare/archive/mengyuang/NeurIPS/final_test/ratio50_final_test", + folder_B="../../../defaultShare/archive/mengyuang/SonyA7M4data_latest/groundtruth" + ) + results = calculate_metrics_with_correction( + folder_A="../../../defaultShare/archive/mengyuang/NeurIPS/final_test/ratio100_final_test", + folder_B="../../../defaultShare/archive/mengyuang/SonyA7M4data_latest/groundtruth" + ) + results = calculate_metrics_with_correction( + folder_A="../../../defaultShare/archive/mengyuang/NeurIPS/final_test/ratio200_final_test", + folder_B="../../../defaultShare/archive/mengyuang/SonyA7M4data_latest/groundtruth" + ) \ No newline at end of file From baf402addbeae03bd1f92fbc5f3d0c878ba4aeeb Mon Sep 17 00:00:00 2001 From: mya012 <96360296+mya012@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:35:00 +0800 Subject: [PATCH 5/6] Add files via upload --- ultraled/__init__.py | 12 + ultraled/batch_test.py | 67 ++ ultraled/models/__init__.py | 29 + .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 1252 bytes .../__pycache__/base_model.cpython-38.pyc | Bin 0 -> 14085 bytes .../__pycache__/lr_scheduler.cpython-38.pyc | Bin 0 -> 6397 bytes .../raw_denoising_model.cpython-38.pyc | Bin 0 -> 24789 bytes ultraled/models/base_model.py | 392 +++++++ ultraled/models/lr_scheduler.py | 147 +++ ultraled/models/raw_denoising_model.py | 970 ++++++++++++++++++ ultraled/ops/__init__.py | 0 .../ops/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 135 bytes ultraled/ops/dcn/__init__.py | 7 + ultraled/ops/dcn/deform_conv.py | 379 +++++++ ultraled/ops/dcn/src/deform_conv_cuda.cpp | 685 +++++++++++++ .../ops/dcn/src/deform_conv_cuda_kernel.cu | 867 ++++++++++++++++ ultraled/ops/dcn/src/deform_conv_ext.cpp | 164 +++ ultraled/ops/fused_act/__init__.py | 3 + ultraled/ops/fused_act/fused_act.py | 95 ++ 
ultraled/ops/fused_act/src/fused_bias_act.cpp | 26 + .../fused_act/src/fused_bias_act_kernel.cu | 100 ++ ultraled/ops/upfirdn2d/__init__.py | 3 + ultraled/ops/upfirdn2d/src/upfirdn2d.cpp | 24 + .../ops/upfirdn2d/src/upfirdn2d_kernel.cu | 370 +++++++ ultraled/ops/upfirdn2d/upfirdn2d.py | 192 ++++ ultraled/test.py | 45 + ultraled/train.py | 222 ++++ ultraled/utils/__init__.py | 53 + .../utils/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 1395 bytes .../__pycache__/color_util.cpython-38.pyc | Bin 0 -> 7624 bytes .../utils/__pycache__/common.cpython-38.pyc | Bin 0 -> 889 bytes .../utils/__pycache__/diffjpeg.cpython-38.pyc | Bin 0 -> 16162 bytes .../__pycache__/dist_util.cpython-38.pyc | Bin 0 -> 2590 bytes .../__pycache__/file_client.cpython-38.pyc | Bin 0 -> 6489 bytes .../img_process_util.cpython-38.pyc | Bin 0 -> 2738 bytes .../utils/__pycache__/img_util.cpython-38.pyc | Bin 0 -> 6111 bytes .../utils/__pycache__/logger.cpython-38.pyc | Bin 0 -> 6198 bytes .../matlab_functions.cpython-38.pyc | Bin 0 -> 4105 bytes .../utils/__pycache__/misc.cpython-38.pyc | Bin 0 -> 4357 bytes .../utils/__pycache__/options.cpython-38.pyc | Bin 0 -> 5493 bytes .../utils/__pycache__/process.cpython-38.pyc | Bin 0 -> 6705 bytes .../utils/__pycache__/registry.cpython-38.pyc | Bin 0 -> 2811 bytes .../__pycache__/torchinterp1d.cpython-38.pyc | Bin 0 -> 4374 bytes ultraled/utils/color_util.py | 208 ++++ ultraled/utils/common.py | 17 + ultraled/utils/diffjpeg.py | 515 ++++++++++ ultraled/utils/dist_util.py | 82 ++ ultraled/utils/download_util.py | 99 ++ ultraled/utils/file_client.py | 167 +++ ultraled/utils/flow_util.py | 170 +++ ultraled/utils/img_process_util.py | 83 ++ ultraled/utils/img_util.py | 172 ++++ ultraled/utils/lmdb_util.py | 196 ++++ ultraled/utils/logger.py | 217 ++++ ultraled/utils/matlab_functions.py | 178 ++++ ultraled/utils/misc.py | 141 +++ ultraled/utils/options.py | 211 ++++ ultraled/utils/process.py | 223 ++++ ultraled/utils/registry.py | 88 ++ ultraled/utils/torchinterp1d.py | 164 +++ 60 files changed, 7783 insertions(+) create mode 100644 ultraled/__init__.py create mode 100644 ultraled/batch_test.py create mode 100644 ultraled/models/__init__.py create mode 100644 ultraled/models/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/models/__pycache__/base_model.cpython-38.pyc create mode 100644 ultraled/models/__pycache__/lr_scheduler.cpython-38.pyc create mode 100644 ultraled/models/__pycache__/raw_denoising_model.cpython-38.pyc create mode 100644 ultraled/models/base_model.py create mode 100644 ultraled/models/lr_scheduler.py create mode 100644 ultraled/models/raw_denoising_model.py create mode 100644 ultraled/ops/__init__.py create mode 100644 ultraled/ops/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/ops/dcn/__init__.py create mode 100644 ultraled/ops/dcn/deform_conv.py create mode 100644 ultraled/ops/dcn/src/deform_conv_cuda.cpp create mode 100644 ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu create mode 100644 ultraled/ops/dcn/src/deform_conv_ext.cpp create mode 100644 ultraled/ops/fused_act/__init__.py create mode 100644 ultraled/ops/fused_act/fused_act.py create mode 100644 ultraled/ops/fused_act/src/fused_bias_act.cpp create mode 100644 ultraled/ops/fused_act/src/fused_bias_act_kernel.cu create mode 100644 ultraled/ops/upfirdn2d/__init__.py create mode 100644 ultraled/ops/upfirdn2d/src/upfirdn2d.cpp create mode 100644 ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu create mode 100644 ultraled/ops/upfirdn2d/upfirdn2d.py create mode 100644 
ultraled/test.py create mode 100644 ultraled/train.py create mode 100644 ultraled/utils/__init__.py create mode 100644 ultraled/utils/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/color_util.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/common.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/diffjpeg.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/dist_util.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/file_client.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/img_process_util.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/img_util.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/logger.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/matlab_functions.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/misc.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/options.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/process.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/registry.cpython-38.pyc create mode 100644 ultraled/utils/__pycache__/torchinterp1d.cpython-38.pyc create mode 100644 ultraled/utils/color_util.py create mode 100644 ultraled/utils/common.py create mode 100644 ultraled/utils/diffjpeg.py create mode 100644 ultraled/utils/dist_util.py create mode 100644 ultraled/utils/download_util.py create mode 100644 ultraled/utils/file_client.py create mode 100644 ultraled/utils/flow_util.py create mode 100644 ultraled/utils/img_process_util.py create mode 100644 ultraled/utils/img_util.py create mode 100644 ultraled/utils/lmdb_util.py create mode 100644 ultraled/utils/logger.py create mode 100644 ultraled/utils/matlab_functions.py create mode 100644 ultraled/utils/misc.py create mode 100644 ultraled/utils/options.py create mode 100644 ultraled/utils/process.py create mode 100644 ultraled/utils/registry.py create mode 100644 ultraled/utils/torchinterp1d.py diff --git a/ultraled/__init__.py b/ultraled/__init__.py new file mode 100644 index 0000000..2843754 --- /dev/null +++ b/ultraled/__init__.py @@ -0,0 +1,12 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .test import * +from .train import * +from .utils import * +# from .version import __gitsha__, __version__ diff --git a/ultraled/batch_test.py b/ultraled/batch_test.py new file mode 100644 index 0000000..88dc135 --- /dev/null +++ b/ultraled/batch_test.py @@ -0,0 +1,67 @@ +import logging +import torch +from os import path as osp +from glob import glob + +from ultraled.data import build_dataloader, build_dataset +from ultraled.models import build_model +from ultraled.train import init_tb_loggers +from ultraled.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs +from ultraled.utils.options import dict2str, parse_options + + +def test_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(root_path, is_train=False) + opt['root_path'] = root_path + experiments_root = osp.join(root_path, 'results', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = 
osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + paths = glob(f"{opt['path']['pretrain_network_g_dir']}/*.pth") + if f"{opt['path']['pretrain_network_g_dir']}/net_g_latest.pth" in paths: + paths.pop(paths.index(f"{opt['path']['pretrain_network_g_dir']}/net_g_latest.pth")) + paths = list(sorted(paths, key=lambda x: int(x[:-4].split('_')[-1]))) + opt['path']['pretrain_network_g_dir'] = None + # create model + model = build_model(opt) + + # initialize wandb and tb loggers + tb_logger = init_tb_loggers(opt) + for load_path in paths: + if load_path.endswith('net_g_latest.pth'): + continue + param_key = opt['path'].get('param_key_g', 'params') + model.load_network(model.net_g, load_path, opt['path'].get('strict_load_g', True), param_key) + current_iter = int(load_path[:-4].split('_')[-1]) + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + model.validation(test_loader, current_iter=current_iter, tb_logger=tb_logger, save_img=opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + test_pipeline(root_path) diff --git a/ultraled/models/__init__.py b/ultraled/models/__init__.py new file mode 100644 index 0000000..4911a3c --- /dev/null +++ b/ultraled/models/__init__.py @@ -0,0 +1,29 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'ultraled.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. 
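+
+    Example (illustrative, not a runnable doctest; it mirrors the call site in
+        ``ultraled/batch_test.py``, where ``opt`` comes from ``parse_options``):
+        >>> opt, _ = parse_options(root_path, is_train=False)
+        >>> model = build_model(opt)  # opt['model_type'] is e.g. 'RAWDenoiserModel'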
+ """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/ultraled/models/__pycache__/__init__.cpython-38.pyc b/ultraled/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07a33bd7f106efc3e30a9634e7b1767da306352e GIT binary patch literal 1252 zcmZWo&yUy6n7;!82g>P9^0W7C0Pwu+u+O;Y zHVTu@0~?Em>+dRaVf|KjHrAc2jQHNz_&nF5$!h_~y~S-ypZ`kK1P=Z6@v;_8xmo8; zximr>RshQ7!?V-nS{jv0!IuuVUNTlSm0=9-v&kJ#RjyTz+ZX8rlh$i((xPolUNx`J;5nC`nN26sTiPf%{Rown;-7`fo+YpM zY@+WP_?%wRXr@A#<%6voUke%GlfoRdvP= zTy7ioy8rKUqnwYyv}^PzrpGj-@6q`?>V`{?ZR9StUaJ*8zZzm-@^97Grjpu$QCJ!K z{@am-5+!I)ZF=MpW4V-oGlJg3`r9Z!*Vv?p+^K?Pqy9Not>i*w)9V!|1`lXE=^N=< oQd%L0H+W&6D3p+m@4!(|-CgWAjWlCAq9bnx76vnV>`nZ?0WB9lQUCw| literal 0 HcmV?d00001 diff --git a/ultraled/models/__pycache__/base_model.cpython-38.pyc b/ultraled/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd0f079abe04ae4ae6588300e836214cf9f99011 GIT binary patch literal 14085 zcmbVTU2q%Mb>3YpfW-m?DTH3L_pm1&zbQ>SN{R<@n1 zr)pE2&$r6$>DsieeOeQiC|uV>VaKS=2>rZPvERYFwdIN)7e6P3Cp~ex*$8oK3D4^` zI^C-kBQBk;hxM~{S#PzxR($w$GYDmKr5B>|$OAMgwWJ#~);-Z{c@p>2YhLKePA7C* zowYSCDz@uE=t;NZx2`tm-J?GyZXU;9@b5@Gjb;Hepw}|O5Si;*%@n4{;+hpXk;gTs zzNeNKwkYD>5=&xAl&|Zxf|wRFD6z!>F^g+aJSOJE9e7d_hs0r&Oo@51fNNPC5qIJ` zEsl!2aGepy#ND_a@MeW0K5$*z(Q60AJ>r8Xn-lkn58--;I4SNIi|Fl;sE7wpau~Hg zDUOQ=uj{BkFHVUQ;v;yzAU-7?5)UJH1T}t6Ji`6mDLyJbhE_+#$Hk+#-X%VPwRk)( zJXsIC=R3k{^^25s+LW(03{1!Hhw?N}PvI{(i6qh^&DBDE*NAkXFF;;4GxCVY_=YfV z(laX0E@;SQ)+jYO%?b_g%q_=xzt?uxx;?3*v)F9(L_N+ly1m$H25uPoLI{PVff@Csy$BuKM`y)!1^~rr!)*cLu!#tWX)n}UO^8La<2$Q8zd@O* zZ99XXC*)eNl-yAviBPLk2fR9qtU}dZ*lYz=DzBbKa)uvPmYk8E(?->?&PJ!hS#0_t znmwf!&Y>ydag8?#R|fct*L!}@kt?0L6wUyXmz)cG3dh^iOuEhXnzOjl>9p{^i|byv zj;{$j4lh8nU0?Hps&m>~twWWUoM#|DUfK!t3wy?LCBtw4sKuRoIhn=Utl#k`$|_kU z1$H_Q@)(vZE+k(QTf>1(iiMr_eJ|PswCGyF)@P0WTq+=Gr7^(Bqi!N4Z{a3{oCtl# zkn;q!8~SEW&XY6@%Yz3(V~2oMr|Y(XIvJ4x5St6yu6a{CtzA0xC1X3YZARv9R$%kG zX>4aBTsIB58e%Kk(fu5*){efN3ymlnCZ>`TAEeG))tNi! zY&p)LZk$)8L7ZFhLdaU2TW_vG5c=kWO@B4cAbIDf_`yOiV&dd|NMe)vuH@tVdOfIz zp_CuuQW_qML!3p0b`WO(EOB-fiV{}xYVk|7_yLf~2O`Ef^&xRijUvuBIzHq%T!BD%jAtdO_eWvyaR1b&%X!;@i&gydprTu$H=JY)q z|J#ROJ@n6i|KC6U)R=OTViJt=ynYE8z%4tO+#`ffmVqs*!inQ!RZS^qem8F=<6KjVEcc{UT8A`|Iv zV}HGo+0n1*Q5I{PU5?EsM7i>y<7WJ)7<1Jmzhc*$S9U$|d?LEo26gNs6*^Q|^?f=7{ z3Cp4w84#3`5)?57K`HlV&qEt~T_?3%t1xZJ(2Wa8S!~i0Lj#M?gH1KhhhF#0xp+3Y zIR}zdmm$jLpXvln-+RLMJrFW~Eq!7=3nQoLp-xHVR4=td7s!sDXm;jYEz9Ncfjv+9 z_W7Y##p%xvrMa1kD9iJ5Wm*xcOt;gOvyc#Z07;yyL+X7YNde^-DXCHN5+$^S#@Vif z9id=jnMyBF@;M|Gld+;!;OQaht(B80VkMC0sqO@V%z(4?x~VHD=%P|^9}>+V0I>ig z$M#8Wl-v3}M*m2v2a{cm?Hpsy5ZYV8$@LICeANJ7pNsSj>;ujo{3lb7ekx{cHz5)b5Fc(5Qdh=h5LX5N z%IP3KLl4GhPCkaRSMjGz7PCwqj`?&l#>vOnRM6@ssPJ{%j7`NNO~r-@UBGleUpRjw zvu*hH@Cg|%yE)=S+joC61lg{8}K(l zLs1?9f8cw}L*J~vI7{Ni#s?gSY;Ghz4hPPs6-<3+)u|7C1OKm`w4ZSvbjDzkm>bCl z320y;A*y(Dgx6&8Zm$J`qrW_{46${xRd26|`l%ae_n-LFMhyX07Z=+nc`~XeTb%|d z#3^J{n|@(gJz8@52ahWOtVTAJE(bx0`U z4CW3=uN`OF^=@pZJxW@3m|6a6XRu4LXw`}sA-W7)w5;XnhB7u~1{+sSxd0}=eroIk z-$d^Lp^26+>qiV58f@u@j2!Hv(fw+y%^t~2=-s2}hYV*jXLo2d)vM4Z3PFe={4EubpyBMEVkf) zfFI)wF5EtO2vn#y*14?iJIVf=x-#VF@d5B;#F>6mfoxpj*&Osf6cps<$QnJ6LcyTMn8;^PJ(=kTaI;370PFzsd<6QfgBgG! 
zy);UOh>~dfh)rbFE@UV@%ockR-axT)orK!mGRUk8p1-kYS8f)Y&g2VGt@3^9lozvzLw4t?U;oHy04OWmnycg(O}h_MZZ5MHM|#BHGnL z+~+4nl|sA%y&U53_wm*5KvF=+C8CUPV|K7BEemuO61vaEk8TvcKs8>)<^N4A!hobR z2EVHVtzj?p1BX6X6Gl>|us=a(fHKZc*z^S@96fTTV>@WYKlX`RuuF;YlxEZ%`-ovR zi#n?C3=qkgL$Ake8eezdo-@brO-cq@$q=0}ef4>%-=pgxnb8=)x6_$3X6>@t6W2{V z%0#a2FNsXgAWJg}J`Ew$w65SIDG*li7DGUHAs#T&lbSmnEigd2v6q}H4oyh>5=QiM zg?fn|r4ZkNL6)LE+lh%3)`F-sEzEM@N=yt;P#Soj9}k?!or4%t5ejVe3+v{vFPei5 zk90<6Imo$Op)t%Pp`ZVpXj?;5?EyH~BAp7JfpptRXbJPf&;BtA ziQvZv*o;U5%MlQZRxQG57!!mNWBvw?q27=BonS&b5?i2R9ME(O&N>?E0i9<1*fmc7 z4YBv!eSQ8{f<+G$#_-a>z;VF(t`t9y(cU=R2^Q4j;tct^3?A#+kO$c!L4VzrEu5%D z@|rDTkaQ^}h|f=t*&7jhCAmFJKOEn7(Y9l$-~1l$j7y)C|C{JRPe5AwMx~HxvwTah z9BaTcy$ZS=--tK$p7~CE{j{2xI!9xr=wsS2aui#o4I}SD4&EE8&1!@Ri9Dm!V8WkW zXP^8M>xA}yn>k@KcDc0>`FIO9KKar_G>H&DZIcwvyvCl-;#CI-6#FoO0m9p@hJ`xLhlX#u4y6Mf?f3i1=rHEcDxr4Jrlq|9EgAg?lwLVqxBCP$dAP?Lkw%c zL;y<&m3q_=JMkRFLPl1R_>?vky@FvSUP%2MaucS+%D@yt%{I&k0f*y{U$kqR5$#Li zsFZg-jzA>v!AQx4=dIPmhHHkz!F&4#<1efyC%AE>d^ES_VO+hY^@pt|taJCRj%YvL zk8-rnGI^E>z3l22nTP}cy-$k)ijiqy60UwI>%k>CVc{`?%!-Gn=9s+5zH3MO93OWv zVYJiNaru10=N5aD2a@}f+mp8Pd{T^YIxyTdqF!uDEdI_oe zaU^T8fNo(Z!lt7BsA65j(O+ZpI+LGZvdm)*WCvxn>P92U{C-@2ykY{Mb`~q`7 z$z+ttPca!|@ei_Zc zN#D=ldOH&AEsX?Izl<|b-(d1vOx9!$0T0SBKT7a8fz1C_SkO(I0Obt>f(Rb8X9$Q; zhXPxMp>dh4^Z&Mhpaf!+83eR0g}{(_!(g95YO-2yMRf9}U>}4a4a19_73nz3wu^~1 zPc=g~+~jsc!jLF9Ak?wwPOJO+UTUm+kr+Ra*%{4EIs^XJ3YDGWk6|{&yzdX7T`%iAd)yvfwI{f6L_G zG5PmQ{sR-~vjGw*NPmrxxF4C#(LTFiRbix08WtM(aG~V1cWr{Wx1)UkIDdB0KI%X5 z@cxj=A2Zo-g4BY($rF4hD*uE5E&Df}>D!?!P^ERC$r^$9|1*?DL1#0Rg#D+P+W>#V zDEl8hQRmG^f^6Fadw&`+APO|zP&p>Ft!yExqySZNs7FIVf*>Kzx$8W>VX>t6-rd@jz0T&AYMAA?%LA26vH&(!UUeK)h=c$$ZO#AMWm-n!25vG06m%q3&Q#TtB@?2gR=xB+x+}V@@2> z!bsl9M|y|Wq8k*yH&7=z)LqOSXL36q>7i67m^{K{1G}TA`Y5Ni$x{t~;GYXm_089D z#{aMTsMR?h?T46rc#A`RKgvGOL+0T#93h8L%ndHz$xYk%wzKz7+1aa};@K3LOmA^Y zbmM)Qr?gK_>3jNMH!HPoq4!m}T#?6Qc8j&qJNZw!HX7)9Zqm)FQP)AL&i4-@jul49CLF;3k*(q1NF%z* z*BxQxeLR+Lj6Iwe6d}~_DZ7=4+*gf z*G50e})hE?}S_&LdxlwR^t3yV@#6c z@rKD0_>2U-?#2WRpYa0t*E7*yQo^M8G{zDr=1lQ#F05oMolGZ^sboBv!Ier50X|03 RJ99VTd-WeO_~*pU{{;a8+x7qe literal 0 HcmV?d00001 diff --git a/ultraled/models/base_model.py b/ultraled/models/base_model.py new file mode 100644 index 0000000..4bba8d3 --- /dev/null +++ b/ultraled/models/base_model.py @@ -0,0 +1,392 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from ultraled.models import lr_scheduler as lr_scheduler +from ultraled.utils import get_root_logger +from ultraled.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. 
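+
+            Note:
+                When ``opt['dist']`` is True this dispatches to
+                ``dist_validation``; otherwise ``nondist_validation`` runs
+                on a single process.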
+ """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. 
+ + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + elif optim_type == 'AdamW': + optimizer = torch.optim.AdamW(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'HandieLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.HandieLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'HandieStepLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.HandieStepLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'TorchCosineAnnealingLR': + print('Torch', 'CosineAnnealingLR') + for optimizer in self.optimizers: + self.schedulers.append(torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. 
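+
+        Note:
+            Warm-up is linear: while ``current_iter < warmup_iter`` the lr of
+            every param group is overridden with
+            ``initial_lr * current_iter / warmup_iter``.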
+ """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. 
+ net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. 
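+
+        Returns:
+            OrderedDict: Scalar (float) losses for logging. When
+            ``opt['dist']`` is True the stacked losses are first reduced to
+            rank 0 and divided by ``world_size``, then converted with
+            ``value.mean().item()``.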
+ """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/ultraled/models/lr_scheduler.py b/ultraled/models/lr_scheduler.py new file mode 100644 index 0000000..0ac1162 --- /dev/null +++ b/ultraled/models/lr_scheduler.py @@ -0,0 +1,147 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +class HandieLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, lrs, last_epoch=-1): + self.milestones = milestones + self.lrs = lrs + super(HandieLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.milestones: + index = self.milestones.index(self.last_epoch) + return [self.lrs[index] for _ in range(len(self.optimizer.param_groups))] + return [group['lr'] for group in self.optimizer.param_groups] + + +class HandieStepLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. 
+ """ + + def __init__(self, optimizer, milestones, gammas, last_epoch=-1): + self.milestones = milestones + self.gammas = gammas + assert len(self.gammas) == len(self.milestones) + super(HandieStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.milestones: + index = self.milestones.index(self.last_epoch) + return [self.gammas[index] * group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' 
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/ultraled/models/raw_denoising_model.py b/ultraled/models/raw_denoising_model.py new file mode 100644 index 0000000..d15e022 --- /dev/null +++ b/ultraled/models/raw_denoising_model.py @@ -0,0 +1,970 @@ +import torch +from torch import nn +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm +import os + +from copy import deepcopy +from ultraled.archs import build_network +from ultraled.losses import build_loss +from ultraled.metrics import calculate_metric +from ultraled.utils import get_root_logger, imwrite, tensor2img +from ultraled.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + +from ultraled.utils import load_CRF, raw2rgb_torch, raw2rgb_torch_grad +from ultraled.data.hdr_util import BlendMertens +from ultraled.data.noise_util_rawhdr import NoiseGenerator + +import yaml +import torch.nn.functional as F + +def sum_img_and_noise(img, noises): + for noise in noises: + img += noise + return img + +def gamma_correct(linear, eps=None): + if eps is None: + eps = torch.finfo(torch.float32).eps + srgb0 = 323 / 25 * linear + srgb1 = (211 * torch.maximum(torch.tensor(eps), linear)**(5 / 12) - 11) / 200 + return torch.where(linear <= 0.0031308, srgb0, srgb1) + + +def gamma_expansion(srgb, eps=None): + if eps is None: + eps = torch.finfo(torch.float32).eps + linear0 = 25 / 323 * srgb + linear1 = torch.maximum(torch.tensor(eps), ((200 * srgb + 11) / (211)))**(12 / 5) + return torch.where(srgb <= 0.04045, linear0, linear1) + +def half_size_demosaic(bayer_images): + r = bayer_images[..., 0:1, :, :] + gr = bayer_images[..., 1:2, :, :] + b = bayer_images[..., 2:3, :, :] + gb = bayer_images[..., 3:4, :, :] + g = (gr + gb) / 2 + linear_rgb = torch.cat([r, g, b], dim=-3) + return linear_rgb + +def apply_gains(bayer_images, wbs): + """Applies white balance to a batch of Bayer images.""" + B, N, C, _, _ = bayer_images.shape + outs = bayer_images * wbs.view(B, -1, C, 1, 1) + return outs + +def apply_ccms(images, ccms): + """Applies color correction matrices.""" + images = images.permute( + 0, 1, 3, 4, 2) # Permute the image tensor to BxHxWxC format from BxCxHxW format + images = images[:, :, :, :, None, :] + ccms = ccms[:, :, None, None, :, :] + outs = torch.sum(images * ccms, dim=-1) + # Re-Permute the tensor back to BxCxHxW format + outs = outs.permute(0, 1, 4, 2, 3) + return outs + + +def tiny_isp(im, wb, ccm): + im = half_size_demosaic(im) + im = apply_gains(im, wb) + im = apply_ccms(im, ccm) + return gamma_correct(im).clip(0, 1) + +def reverse_tiny_isp(srgb, wb, ccm): + raw = gamma_expansion(srgb) + raw = apply_ccms(raw, torch.inverse(ccm)) + raw = apply_gains(raw, 1.0 / wb) + # expand to 4 channels + raw_g = raw[..., 1:2, :, :] + raw = torch.cat([raw, raw_g], dim=-3) + return raw + +def exposure_fusion_from_raw(ims, wb, ccm, blend_menten): + """ + ims: B, N, C, H, W + wb: B, 4 + ccm: B, 3, 3 + """ + wb = wb[..., :3] # B, 1, 3 + ccm 
= ccm.unsqueeze(1) # B, 1, 3, 3 + srgbs = tiny_isp(ims, wb, ccm) # B, N, 3, H, W + merged = blend_menten(*[srgbs[:, i] for i in range(srgbs.shape[1])]) # B, 3, H, W + merged_raw = reverse_tiny_isp(merged.unsqueeze(1), wb, ccm).squeeze(1) # B, 4, H, W + return merged_raw + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + tuple: yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + +def yaml_load(f): + """Load yaml file or string. + + Args: + f (str): File path or a python string. + + Returns: + dict: Loaded dict. + """ + if os.path.isfile(f): + with open(f, 'r') as f: + return yaml.load(f, Loader=ordered_yaml()[0]) + else: + return yaml.load(f, Loader=ordered_yaml()[0]) + + + +def load_network(net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + print('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + print(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + +class IlluminanceCorrect(nn.Module): + def __init__(self): + super(IlluminanceCorrect, self).__init__() + + # Illuminance Correction + def forward(self, predict, source): + if predict.shape[0] != 1: + output = torch.zeros_like(predict) + if source.shape[0] != 1: + for i in range(predict.shape[0]): + output[i:i+1, ...] = self.correct(predict[i:i+1, ...], source[i:i+1, ...]) + else: + for i in range(predict.shape[0]): + output[i:i+1, ...] 
= self.correct(predict[i:i+1, ...], source) + else: + output = self.correct(predict, source) + return output + + def correct(self, predict, source): + N, C, H, W = predict.shape + predict = torch.clamp(predict, 0, 1) + assert N == 1 + output = torch.zeros_like(predict, device=predict.device) + pred_c = predict[source != 1] + source_c = source[source != 1] + + num = torch.dot(pred_c, source_c) + den = torch.dot(pred_c, pred_c) + output = num / den * predict + + return output + + +@MODEL_REGISTRY.register() +class RatioMapEstimatorModel(BaseModel): + + def __init__(self, opt): + super(RatioMapEstimatorModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + self.blend_merten = BlendMertens(contrast_weight=1.0, saturation_weight=1.0, exposure_weight=1.0, clip=True) + self.noise_gen = NoiseGenerator(**self.opt['noise_g']) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.opt.get('CRF_path', None) is not None: + self.CRF = load_CRF(self.opt['CRF_path']) + else: + self.CRF = None + + self.correct = self.opt['val'].get('illumination_correct', False) + if self.correct: + self.corrector = IlluminanceCorrect() + self.metric_in_srgb = self.opt.get('metric_in_srgb', False) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('srgb_opt'): + self.post_process = lambda x, wb, ccm: raw2rgb_torch_grad(x, wb, ccm, self.CRF) + else: + self.post_process = lambda x, wb, ccm: x + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + learnable_layers = train_opt.get('learnable_layers', None) + learnable_keys = train_opt.get('learnable_keys', None) + optim_params = [] + if learnable_layers is None and learnable_keys is None: + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + else: + if learnable_keys is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_keys\' for query training paarameters ...') + for k, v in 
self.net_g.named_parameters(): + for l_key in learnable_keys: + if l_key in k: + optim_params.append(v) + break + assert len(optim_params) > 0 + if learnable_layers is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_layers\' for query training paarameters ...') + for layer in learnable_layers: + if hasattr(self.net_g, layer): + optim_params.extend(list(eval(f'self.net_g.{layer}').parameters())) + else: + logger = get_root_logger() + logger.error(f'Layer {layer} is not in {self.net_g.__name__}.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + + self.intact = data['intact'].to(self.device) + self.lq_clean = data['lq_clean'].to(self.device) + self.ccm = data['ccm'].to(self.device) + self.wb = data['wb'].to(self.device) + self.ratio = data['ratio'].to(self.device) + self.ratio_all = data['ratio1'].to(self.device) + + mexp_lq = data['gt'].to(self.device) + fused_raw = torch.clamp(exposure_fusion_from_raw(mexp_lq, self.wb, self.ccm, self.blend_merten), 1e-8) + self.gt = self.lq_clean / fused_raw + self.ratio_all = self.ratio_all.unsqueeze(1).unsqueeze(1).unsqueeze(1) + self.gt = self.gt.clip(1 / self.ratio_all, self.ratio_all) + + # add noise + lq_im_patch = torch.clamp(self.lq_clean , min=0) * (16383 - 512) / self.ratio_all + im_patch, noise1 = self.noise_gen(lq_im_patch) + lq_im_patch = sum_img_and_noise(im_patch, noise1) / (16383 - 512) * self.ratio_all + + self.lq = lq_im_patch + + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + index = self.opt['network_g'] + net = index['type'] + + if str(net) == 'UNetArch' or str(net) == 'Restormer': + self.output = self.net_g(self.lq) + else: + self.output = self.net_g(self.lq, self.ratiomap) + + self.output = self.post_process(self.output, self.wb, self.ccm) + self.gt = self.post_process(self.gt, self.wb, self.ccm) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + # padding + h, w = self.lq.shape[2:] + pad_h = 16 - (h % 16) if h % 16 != 0 else 0 + pad_w = 16 - (w % 16) if w % 16 != 0 else 0 + + self.gt = self.gt.squeeze(0) + self.lq = nn.functional.pad(self.lq, [0, pad_w, 0, pad_h], mode='replicate') + self.gt = nn.functional.pad(self.gt, [0, pad_w, 0, pad_h], mode='replicate') + + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + else: + self.net_g.eval() + with torch.no_grad(): + index = self.opt['network_g'] + net = index['type'] + if str(net) == 'UNetArch': + self.output = self.net_g(self.lq) + + else: + self.output = self.net_g(self.lq, self.ratio) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + self.net_g.train() + + self.output 
= self.output[:, :, :h, :w] + self.lq = self.lq[:, :, :h, :w] + self.gt = self.gt[:, :, :h, :w] + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + @property + def calculate_metric_in_batch(self): + if hasattr(self, '_calculate_metric_in_batch'): + return self._calculate_metric_in_batch + + ## init + self._calculate_metric_in_batch = False + if self.opt['val'].get('calculate_metric_in_batch', False) is True: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + keys = filter(lambda x: x.startswith('val'), list(self.opt['datasets'].keys())) + for key in keys: + if self.opt['datasets'][key].get('batch_size_per_gpu', 1) > 1: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + return self._calculate_metric_in_batch + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', True) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + self._initialize_best_metric_results(dataset_name) + + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + if self.calculate_metric_in_batch: + count = 0 + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals(self.metric_in_srgb, save_img) + if not self.calculate_metric_in_batch: + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + else: + metric_data['img'] = visuals['result'] + metric_data['img2'] = visuals['gt'] + count += visuals['gt'].shape[0] + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + del self.ccm + del self.wb + torch.cuda.empty_cache() + + psnr = None + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + if self.calculate_metric_in_batch and not opt_['type'].endswith('_pt'): + opt_['type'] = opt_['type'] + '_pt' + metric = calculate_metric(metric_data, opt_) + if self.calculate_metric_in_batch: + metric = torch.sum(metric) + self.metric_results[name] += metric + if name == 'psnr': + psnr = metric + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + + if save_img: + if not self.calculate_metric_in_batch: + if not self.metric_in_srgb: + sr_img = tensor2img([visuals['result_srgb']]) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.jpg') + imwrite(sr_img, save_img_path) + else: + if not self.metric_in_srgb: + sr_imgs = tensor2img(visuals['result_srgb']) + else: + sr_imgs = tensor2img(visuals['result']) + if len(sr_imgs.shape) == 3: + if 
self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}_{psnr:.4f}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}_{psnr:.4f}.jpg') + imwrite(sr_imgs, save_img_path) + else: + raise NotImplementedError() + + + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + if not self.calculate_metric_in_batch: + self.metric_results[metric] /= (idx + 1) + else: + self.metric_results[metric] /= count + self.metric_results[metric] = self.metric_results[metric].item() + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self, isp=True, save_img=False): + out_dict = OrderedDict() + if isp: + out_dict['lq'] = raw2rgb_torch(self.lq.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['result'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['gt'] = raw2rgb_torch(self.gt.detach(), self.wb, self.ccm, self.CRF, batch=True) + else: + out_dict['lq'] = self.lq.detach() + out_dict['result'] = self.output.detach() + out_dict['gt'] = self.gt.detach() + if save_img: + out_dict['result_srgb'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + if not self.calculate_metric_in_batch: + out_dict['result_srgb'] = out_dict['result_srgb'].cpu() + if not self.calculate_metric_in_batch: + out_dict['lq'] = out_dict['lq'].cpu() + out_dict['result'] = out_dict['result'].cpu() + out_dict['gt'] = out_dict['gt'].cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) + + + + + + +@MODEL_REGISTRY.register() +class RAWDenoiserModel(BaseModel): + + def __init__(self, opt): + super(RAWDenoiserModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + self.blend_merten = BlendMertens(contrast_weight=1.0, saturation_weight=1.0, exposure_weight=1.0, clip=True) + self.noise_gen = NoiseGenerator(**self.opt['noise_g']) + + network_d = build_network(opt['network_d']) + load_network(network_d, opt['network_d_path']) + self.mapnet = network_d.to(self.device) + + # load pretrained models + load_path = 
self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.opt.get('CRF_path', None) is not None: + self.CRF = load_CRF(self.opt['CRF_path']) + else: + self.CRF = None + + self.correct = self.opt['val'].get('illumination_correct', False) + if self.correct: + self.corrector = IlluminanceCorrect() + self.metric_in_srgb = self.opt.get('metric_in_srgb', False) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('srgb_opt'): + self.post_process = lambda x, wb, ccm: raw2rgb_torch_grad(x, wb, ccm, self.CRF) + else: + self.post_process = lambda x, wb, ccm: x + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + learnable_layers = train_opt.get('learnable_layers', None) + learnable_keys = train_opt.get('learnable_keys', None) + optim_params = [] + if learnable_layers is None and learnable_keys is None: + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + else: + if learnable_keys is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_keys\' for query training paarameters ...') + for k, v in self.net_g.named_parameters(): + for l_key in learnable_keys: + if l_key in k: + optim_params.append(v) + break + assert len(optim_params) > 0 + if learnable_layers is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_layers\' for query training paarameters ...') + for layer in learnable_layers: + if hasattr(self.net_g, layer): + optim_params.extend(list(eval(f'self.net_g.{layer}').parameters())) + else: + logger = get_root_logger() + logger.error(f'Layer {layer} is not in {self.net_g.__name__}.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + + self.intact = data['intact'].to(self.device) + self.lq_clean = data['lq_clean'].to(self.device) + self.ccm = data['ccm'].to(self.device) + self.wb = data['wb'].to(self.device) + self.ratio = data['ratio'].to(self.device) + self.ratio_all = 
data['ratio1'].to(self.device) + + mexp_lq = data['gt'].to(self.device) + fused_raw = torch.clamp(exposure_fusion_from_raw(mexp_lq, self.wb, self.ccm, self.blend_merten), 1e-8) + self.ratio_all = self.ratio_all.unsqueeze(1).unsqueeze(1).unsqueeze(1) + # add noise + lq_im_patch = torch.clamp(self.lq_clean , min=0) * (16383 - 512) / self.ratio_all + im_patch, noise1 = self.noise_gen(lq_im_patch) + lq_im_patch = sum_img_and_noise(im_patch, noise1) / (16383 - 512) * self.ratio_all + + with torch.no_grad(): + ratiomap = self.mapnet(lq_im_patch) + self.lq = lq_im_patch / (ratiomap + 1e-8) + self.ratiomap = self.ratio_all / (ratiomap + 1e-8) + self.gt = fused_raw + self.lq, self.gt = self.lq.clip(0, 1), self.gt.clip(0, 1) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + index = self.opt['network_g'] + net = index['type'] + + if str(net) == 'UNetArch' or str(net) == 'Restormer': + self.output = self.net_g(self.lq) + else: + self.output = self.net_g(self.lq, self.ratiomap) + + self.output = self.post_process(self.output, self.wb, self.ccm) + self.gt = self.post_process(self.gt, self.wb, self.ccm) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + # padding + h, w = self.lq.shape[2:] + pad_h = 16 - (h % 16) if h % 16 != 0 else 0 + pad_w = 16 - (w % 16) if w % 16 != 0 else 0 + + self.gt = self.gt.squeeze(0) + self.lq = nn.functional.pad(self.lq, [0, pad_w, 0, pad_h], mode='replicate') + self.gt = nn.functional.pad(self.gt, [0, pad_w, 0, pad_h], mode='replicate') + + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + else: + self.net_g.eval() + with torch.no_grad(): + index = self.opt['network_g'] + net = index['type'] + if str(net) == 'UNetArch': + self.output = self.net_g(self.lq) + + else: + self.output = self.net_g(self.lq, self.ratio) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + self.net_g.train() + + self.output = self.output[:, :, :h, :w] + self.lq = self.lq[:, :, :h, :w] + self.gt = self.gt[:, :, :h, :w] + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + @property + def calculate_metric_in_batch(self): + if hasattr(self, '_calculate_metric_in_batch'): + return self._calculate_metric_in_batch + + ## init + self._calculate_metric_in_batch = False + if self.opt['val'].get('calculate_metric_in_batch', False) is True: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + keys = filter(lambda x: x.startswith('val'), list(self.opt['datasets'].keys())) + for key in keys: + if self.opt['datasets'][key].get('batch_size_per_gpu', 1) > 1: + self._calculate_metric_in_batch = True + return 
self._calculate_metric_in_batch + return self._calculate_metric_in_batch + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', True) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + if self.calculate_metric_in_batch: + count = 0 + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals(self.metric_in_srgb, save_img) + if not self.calculate_metric_in_batch: + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + else: + metric_data['img'] = visuals['result'] + metric_data['img2'] = visuals['gt'] + count += visuals['gt'].shape[0] + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + del self.ccm + del self.wb + torch.cuda.empty_cache() + + psnr = None + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + if self.calculate_metric_in_batch and not opt_['type'].endswith('_pt'): + opt_['type'] = opt_['type'] + '_pt' + metric = calculate_metric(metric_data, opt_) + if self.calculate_metric_in_batch: + metric = torch.sum(metric) + self.metric_results[name] += metric + if name == 'psnr': + psnr = metric + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + + if save_img: + if not self.calculate_metric_in_batch: + if not self.metric_in_srgb: + sr_img = tensor2img([visuals['result_srgb']]) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.jpg') + imwrite(sr_img, save_img_path) + else: + if not self.metric_in_srgb: + sr_imgs = tensor2img(visuals['result_srgb']) + else: + sr_imgs = tensor2img(visuals['result']) + if len(sr_imgs.shape) == 3: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}_{psnr:.4f}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}_{psnr:.4f}.jpg') + imwrite(sr_imgs, save_img_path) + else: + raise NotImplementedError() + + + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + if not self.calculate_metric_in_batch: + self.metric_results[metric] /= (idx + 1) + else: + 
self.metric_results[metric] /= count + self.metric_results[metric] = self.metric_results[metric].item() + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self, isp=True, save_img=False): + out_dict = OrderedDict() + if isp: + out_dict['lq'] = raw2rgb_torch(self.lq.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['result'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['gt'] = raw2rgb_torch(self.gt.detach(), self.wb, self.ccm, self.CRF, batch=True) + else: + out_dict['lq'] = self.lq.detach() + out_dict['result'] = self.output.detach() + out_dict['gt'] = self.gt.detach() + if save_img: + out_dict['result_srgb'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + if not self.calculate_metric_in_batch: + out_dict['result_srgb'] = out_dict['result_srgb'].cpu() + if not self.calculate_metric_in_batch: + out_dict['lq'] = out_dict['lq'].cpu() + out_dict['result'] = out_dict['result'].cpu() + out_dict['gt'] = out_dict['gt'].cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) \ No newline at end of file diff --git a/ultraled/ops/__init__.py b/ultraled/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ultraled/ops/__pycache__/__init__.cpython-38.pyc b/ultraled/ops/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1611f1648baf318ed21e17546d8391fdb2e4cde8 GIT binary patch literal 135 zcmWIL<>g`kf?Z*2GeGoX5P=LBfgA@QE@lA|DGb33nv8xc8Hzx{2;!HHer{@BdSz*1 zUb=p0PDxRskE@G*DTtAinxdayP^=#xpP83g5+AQuP 0, output_size)): + raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + 
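+        # _infer_shape applies the standard convolution output-size formula,
+        # (in + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1, and the
+        # pre-allocated output buffer is then filled in-place by the CUDA extension call below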
ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. 
/ math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. 
+ kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/ultraled/ops/dcn/src/deform_conv_cuda.cpp b/ultraled/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 0000000..b465c49 --- /dev/null +++ b/ultraled/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const 
int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt 
< batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, 
+ outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight 
= + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h 
* (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu b/ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000..98752dc --- /dev/null +++ b/ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! 
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + 
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, 
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t 
weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % 
kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= 
height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col 
* stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int 
data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + 
h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + 
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/ultraled/ops/dcn/src/deform_conv_ext.cpp b/ultraled/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 0000000..41c6df6 --- /dev/null +++ b/ultraled/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int 
deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, 
const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/ultraled/ops/fused_act/__init__.py b/ultraled/ops/fused_act/__init__.py new file mode 100644 index 0000000..241dc07 --- /dev/null +++ b/ultraled/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/ultraled/ops/fused_act/fused_act.py b/ultraled/ops/fused_act/fused_act.py new file mode 100644 index 0000000..88edc44 --- /dev/null +++ b/ultraled/ops/fused_act/fused_act.py @@ -0,0 +1,95 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import os +import torch +from torch import nn +from torch.autograd import Function + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) +else: + try: + from . import fused_act_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. 
set BASICSR_JIT=True during running') + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/ultraled/ops/fused_act/src/fused_bias_act.cpp b/ultraled/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 0000000..85ed0a7 --- /dev/null +++ b/ultraled/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/ultraled/ops/fused_act/src/fused_bias_act_kernel.cu b/ultraled/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 0000000..54c7ff5 --- /dev/null +++ b/ultraled/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/ultraled/ops/upfirdn2d/__init__.py b/ultraled/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000..397e85b --- /dev/null +++ b/ultraled/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/ultraled/ops/upfirdn2d/src/upfirdn2d.cpp b/ultraled/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 0000000..43d0b67 --- /dev/null +++ b/ultraled/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git 
a/ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 0000000..8870063 --- /dev/null +++ b/ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = 
blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int 
tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/ultraled/ops/upfirdn2d/upfirdn2d.py b/ultraled/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 0000000..d6122d5 --- /dev/null +++ b/ultraled/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,192 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import os +import torch +from torch.autograd import Function +from torch.nn import functional as F + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) +else: + try: + from . import upfirdn2d_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. 
set BASICSR_JIT=True during running') + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + _, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = 
input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/ultraled/test.py b/ultraled/test.py new file mode 100644 index 0000000..a9bdb85 --- /dev/null +++ b/ultraled/test.py @@ -0,0 +1,45 @@ +import logging +import torch +from os import path as osp + +from ultraled.data import build_dataloader, build_dataset +from ultraled.models import build_model +from ultraled.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs +from ultraled.utils.options import dict2str, parse_options + + +def test_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(root_path, is_train=False) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + # create model + model = build_model(opt) + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + test_pipeline(root_path) diff --git a/ultraled/train.py b/ultraled/train.py new file mode 100644 index 0000000..df884a3 --- /dev/null +++ b/ultraled/train.py @@ -0,0 +1,222 @@ +import datetime +import logging +import math +import time +import torch +from os import path as osp +import sys + + +from ultraled.data import build_dataloader, build_dataset +from ultraled.data.data_sampler import EnlargedSampler +from ultraled.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from ultraled.models import build_model +from ultraled.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, + init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) +from ultraled.utils.options import copy_opt_file, dict2str, parse_options + + +def 
init_tb_loggers(opt): + # initialize wandb logger before tensorboard logger to allow proper sync + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') + is not None) and ('debug' not in opt['name']): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: + tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name'])) + return tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loaders = None, [] + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase.split('_')[0] == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}') + val_loaders.append(val_loader) + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loaders, total_epochs, total_iters + + +def load_resume_state(opt): + resume_state_path = None + if opt['auto_resume']: + state_path = osp.join('experiments', opt['name'], 'training_states') + if osp.isdir(state_path): + states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) + if len(states) != 0: + states = [float(v.split('.state')[0]) for v in states] + resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') + opt['path']['resume_state'] = resume_state_path + else: + if opt['path'].get('resume_state'): + resume_state_path = opt['path']['resume_state'] + + if resume_state_path is None: + resume_state = None + else: + device_id = torch.cuda.current_device() + resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) + check_resume(opt, resume_state['iter']) + return resume_state + + +def train_pipeline(root_path): + # parse options, set distributed setting, set random seed + opt, args = parse_options(root_path, is_train=True) + opt['root_path'] = root_path + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + resume_state = load_resume_state(opt) + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if 
opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name'])) + + # copy the yml file to the experiment root + copy_opt_file(args.opt, opt['path']['experiments_root']) + + # WARNING: should not use get_root_logger in the above codes, including the called functions + # Otherwise the logger will not be properly initialized + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + # initialize wandb and tb loggers + tb_logger = init_tb_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loaders, total_epochs, total_iters = result + + # create model + model = build_model(opt) + if resume_state: # resume training + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. 
Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') + data_timer, iter_timer = AvgTimer(), AvgTimer() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_timer.record() + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_timer.record() + if current_iter == 1: + # reset start time in msg_logger for more accurate eta_time + # not work in resume mode + msg_logger.reset_start_time() + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + # if epoch%2 == 0: + # logger.info('Saving models and training states.') + # model.save(epoch, current_iter) + + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): + if len(val_loaders) > 1: + logger.warning('Multiple validation datasets are *only* supported by SRModel.') + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + # pass + + data_timer.start() + iter_timer.start() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. 
Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None: + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/ultraled/utils/__init__.py b/ultraled/utils/__init__.py new file mode 100644 index 0000000..1a29958 --- /dev/null +++ b/ultraled/utils/__init__.py @@ -0,0 +1,53 @@ +from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb +from .diffjpeg import DiffJPEG +from .file_client import FileClient +from .img_process_util import USMSharp, usm_sharp +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt +from .common import AverageMeter +from .process import raw2rgb_postprocess, load_CRF, raw2rgb_v2, raw2rgb_torch, raw2rgb_torch_grad + +__all__ = [ + # color_util.py + 'bgr2ycbcr', + 'rgb2ycbcr', + 'rgb2ycbcr_pt', + 'ycbcr2bgr', + 'ycbcr2rgb', + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'AvgTimer', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt', + # diffjpeg + 'DiffJPEG', + # img_process_util + 'USMSharp', + 'usm_sharp', + # average meter + 'AverageMeter', + 'raw2rgb_postprocess', + 'raw2rgb_v2', + 'raw2rgb_torch', + 'raw2rgb_torch_grad', + 'load_CRF' +] diff --git a/ultraled/utils/__pycache__/__init__.cpython-38.pyc b/ultraled/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e360cae30678a71197e5168c790534d5fd253d2a GIT binary patch literal 1395 zcmZ{k$!;Sz5Qe*DEv?O3Y|EQ$OWrjTM1XP)f(&K}5R8LhJagfNXmwHTj#*8C-Lm7n zLLMNeJV2f(0Ixaa0dmPHEXwjs0)!2Hs$#K8{=Z0lXf|t}Z9o6=*Uhgb&-;UtPutKY zU)xXC^F0qd@TG@7^Rd7RxW?A7$ck8EC7)^qS;h*hxOhzlSY=fg7iA6Wtd0%VaCu2K zah0{V<3i&)YRac+bk*m;_N4^POC9~PL-@pFG#zh!|O?8nW&nUEu)g!Iue5y1EZR!cmw8H5_E;3U~FrMTh&6GBk1n+f{3sYNmOl)mb zHo=R#_FCCE|CRDshoB4g2i0qYBEjAaPnQV+p-E^F8iaL1ozNk)37dp1!UmyB z*e7fgb_qL#JwlJrCsYVk0$odeKsY1}2y_|sG2zJKyU}x2i|-c%m9wRC}G|Hleb;XIU*^WQvuL1n_G@!Qw0#tWJ2Pzo56rfkgZLANKGe?EUO!I!9DA;nj8owN5AeZOI! 
iw%_*${;7@n#e)C%6m0t#uOk1GUg}$p`Z&@r{QeJY-(n{K literal 0 HcmV?d00001 diff --git a/ultraled/utils/__pycache__/color_util.cpython-38.pyc b/ultraled/utils/__pycache__/color_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8782c43e40501bf6b745ce33c83bb7501b4f627 GIT binary patch literal 7624 zcmeHMO>7&-72ZGoh?ZnIu^q>b-Lc&^l0%V}Tqmg$#ERt~B#je8QGV>!m6n_#xi-1W z&JLxjpn$Ya0u*geJ>-zIka8(>ErJ|^0@0}m&|A^d0zq07D4|o)rawU1!hP>8xm-&A zQDLPhYE$Cl?#}GYn|brzeD9l=1_m-3e17oOUn`?u(6ry;b?DG*%adR4om z)30&K;5)dn~~}JJ+CH?fJ7s1Ap5e3x3n^orJGn0@2j2Xp4r@o7y#fQETe*h%%PUCaA*=ZkEku zgIf=1i>8}{=N3H&b;>4`f;#YP@+9vohcz>x(zddR^;M#2HjQh>qS-V-nST3K81JO# z&IzfQ&8E(sVouF23#L4Fak4Ne!`oaY!{9SraeP*4xJBi7E~`tAH;Te%sv?+g*94Qo zZ&a1fY@eOCRn;y$?6X-}E{rb}3q?5_y_%!eo>9d6yq7H9Lx(WhV{z5?BS+a|Q^Mwk1b>~oKgeLS&= z_?Yd9kEhN}h0bxa_$XMREEz;oid$ZZMWb%Je~-`qAI3bKp?fj!?u~g8p^OBl(If&K ziOw!lC=hl@1Ua$~MH0m}6#Xa$Q0zp3K_(P?;VS3@hYIn-m*Mu-apNJVQPe?nKBz~8 zs}Rr;uF^wX#r2)qqSnS$1|TPcaFc~`0nk$laTX!1^iB=IT^mCg%4!--9Wa#9Gzdd& z>xmID*o09cf>d{ap$I;WLnw*R#$6Ik2su$afoZ&lpmMiEPy{UH03I;y!I&c>#HSp&n<_$9fKPkUMGD-I zZ@1IZFz1ZbKqLVs287 z0uwgBZWqHOw@ZRVu)}A5qgX+HC(3IWsu2r4G?M>qp5ey3F+!?jb;bZg_`jzUZcucLBv}FJ9Vg)EGqyNQ}z@aNz$$tByAtx zq>V2~wbw$BAR;E{n}(t@QhIX4i0y>E-a;GacX7Gn=BW+g_LtEO_-2ejol1N1?u{Oz zEA0w?jgoOSB@ZZ*>#_AurJ;u`ShcMXK@ zcWcTETc=hw^w2gHQGPepH@2WiWn694R+m+~jdy##io17vBuGj|UL7MBsqGGvyLDdO znf{)+BhOhsK(GnvK&}F!J5FQ|yIf$$SU^;ED92{ZVSpY4lSY-|1_ge#r-z$P{3fm1 z(R5vB|FLa*lv>A}qK&D11BTyBbh^o+rv}5|6Y0MB;6tGcA`qeNQ3x6`iy{XiNJy{Y za`^~cw}VcQ?sTzkpC*j{4pj6p-Pd*v>4Ow?(mIv&L>=S|qOAQRxPkMJXBF(qTrb1< z2Yk)44w<>WWHdF{r41;9vyml>L+*}%dl@qHWrOZ9k54RHJWZ!1GHeC2dC2h%JgT9b zKf?zIq{my-`5+wJY^Qz_&1;dfL-Ay3c9YAs1wPMQW$8M*7 zsB%b(A?OT=GE9!waF<=fXKQ#HQf5St5pRZ`5K?Bupmiy;!%#2+=Ku;9N|01Ezkva_ zm=p&oKnk^n5`i`6i20F32nsOm4U+T1DOVJw$$?dHY=51%2K^Dfo2jc0;GGaK3|RU@ z`VnJTKco*E;JQzs3i{dhePQAno^YTB&+=ZS#7T!s$j7ja2E|*XOA% zjm}FYQ)}jBn-)z~HTBWr+QmZ5>$wpE`D*c0LRXKsr#ZUI_Z4z34XSSzoAvuvNO$r{wq@B^l24W`V+&i7L0ZdR?`_ZR?P(w9 zo;gXk*&`&9Z3u`!0+@#bdq65EA%p-)AXQ)>p#mY5auo@rCR9R^0!qa{sY)fOA}+}9 z*E2i2d$(uVj!7|F^G)|RJ-yT2->bjxn@3iyN@+OW`RspA?|Or#{TnNt#~>>Eak@uM zO%qyH6S^?UdRA9gBWtRwm9^A0R*7fhI@_4#SS69QbuQXp27l z`o$_SAXbY(u?CoaF%-5~D~827g;_5)h>dt+lh`b_0J9bKtHe5SwRoLKi)(;Mh-<|f zah7r>GqB@DyJ+LBQexLdW2P_L@5Gy{ z9IH9t^e^g*I_l=kYO%Z3B0J~|HrI&pMV(`5CdGQZMMB(Cv~ax@?|&Og30&8qXG`27 zZWH4y6|F^mrq3A^hS>9z5x^n2BheFIIVD1(70drr7 z*$m8KaX&EkhnOwEyiq&=%mX22D=-g=BfuO9F;@ZeCh-t34~3Ykfq7UQ1?Fgoc^xpv z#BpGbhnO@lkBBz|^X3q94KQyJCxAH-Vy*?|q{sr34Kdd_*NaC_8}BjY5k+4&WN*Or zRQblrP1&0&BiRwG$qmhncx+J@k8jbk+nm=sH_U85uZvvK$nLZ{dV-Ov_By zKbm^5CQ4H!N2I4@t&*Phyjgee*jTaTP0vk^7HXBTO1W04mc07d!P3;!{SV!9Zz^@D zCex)#zUa8=YRyZF(wUN5nk+l%$@%o2RI3l+&W^fXzHr((SD4OMi;n8<$vMaMO0}vx zw(FK%yKWzwnRSXgM8VrJn-`)~EfT$Z*L4b}E>jvD$>@G+YOY$~WBKxBN`=dP1e6|^ zHxi5x>>_|1(JuEB3=pg(*i5jO;7)?u2v!l?PH+#w9)dLl>j|zRcniTp1a}b(5!^s< zkf1_PBzTbEI)eQK4--riJV|hh;AVm&1oH&v2pod@2yz6I1hWK!pa5`5_hZ7FpLH_E zWgB(5n|HKqL^ zIO=$Fvf9yydUkJmqTMollhI&m$1xs~lBT`=^ZSE>R@#4=a}m}~?pJdRib#Ynr0M}6 zdHQQYe9o)o!eIjKl5Vu5t*4BBU6ODMM06gURUYmLfTKaXYFR_*S@Rif3i?)<7i?%) zXi;d0*fW}?W#gPzf7Q`@?;bxnF*zac$28Qh!SJeQ9O=1f+z8}D8dIxUs9|=kC&!k+ z$$xjJ06fhDg`PK>#-gTco;jnmx-g3-|69ayjZ=pgfx+FFNJ4|g7j!Stw7q0AB~sMu zMU$GXq2>BQe18LVg;^y?5B*1q?>O+pk6yU;JNt^0gC}nP!Z&}oulTJqU-{VPu`ll{ zHh2HS$n_uk#`DF`{^Lre?9UANF?Vj(k&=SNw{y7?#Ahz&8`W7qmX|U=?>Bs=<3LyaO&e5*kDCi7llI(;s^ z3+J>`Do%SSoGl4&I$$h5vW&Ti40>pH(ETubj9x-gKFR|GO8m2m3A9A1;${+l0xMUU z!%C=yNCpcbYMz>CNp0X;XG+dlu1$W{xn#=Y#KqJ|{5S*wg&pN#dm*siHF-8K1=Wws zIo0ez9m}w;dl|x3mlC9!DhL8plgi~PH8EFaxj&a9lZST_xtypKu;5%~nFsJ!g;Y>* 
z%H`y3M0l2P%%g^JLiiI^3)1cP;vAIb%UzsgKkJK~i z0GD>#$aJlmrtC=5HmEvfce#k*%J40Cy41eiB7*FENk5V?<-NEkrwKR}mvprxQjJ-S zUJabLsHp=Dqt`btd(Yl%KvkN+Il*VY4$#d8O??q0?-?^@!+@BI!Q3{(tEx>|o3y4S zAeJD;lmNSEEtp;HDc<b;9Dhv#CXok0UMSBACkInC6PH;$;l~Q) zTGjFOLw+iMrkIM{SlTXB=j4lst;%O~ow*RJ#r`8jn}X@&k95YlYLt z!0et!e!n&yl|wtARD~rlOX==hiP34#FJW}*b5m(=GgiR!b7-r0-q%md8CH5Q z{S-$;E14RwzM=bFd-FW`J>q%JE7X}%eNP#iv{S}{0S+X_+@vjl_h4RI4Va-b#%36v zGlsAT7h?;tMyzQ<9#Cn*gaIbrh&Qc88sspk*h;muz>|skGZAsWN}ha~N-|Qj)7#FD z-MSmhMV)&#`uBi46rWH8U;YJ-fJwMZ0wF(!i{!pWo*__-kSf%$e<;q?=AcE}r0&N9 zvEl1bWlFsVoHY<|0MsS5sPG4aI<^T_*yXMxIw)C*?>zpAE9@$lkv*iR^sV|3ESV+i zc|~p@&2UNE++a!KUNdfpADlS8=iHw9_1IUr1qrJTma$u^I_cva9bOP795xGTzPoH< zlAMi`eCYldpqmc@n^@Vt#zh16S9~Vn*}{a!0V?F8aZr2g#3_r!xnM2CJo<~`4P#gn zi49r<`yxDag? z7O>lGu>tQ$T7JGh@9T5Eeg@ls^6YfJm#|tuJC~jd3addF>__ypzFx1d4v4C^xZ)g8 zwol-6cLKm1(6I}qZH`@4WTnG+xQG>37COY%OpJcT20cO6&z_Fh|CcL#S|)zCH4**! z<+tJWgXKMe64dOwiJ&nqA0<$;sbr6mLXQ)tj64M*QSmhJ!QR)lVQKg4^?`5(uP7i$ zF7(p01w>L3SHRhIaN2pe+W=PFI43CiFX$axXWLJngckwgvJML&O`jUQ3al%^pj-xBk#i%;!8hjxHI2-2VXE^iaef0bWfiX*;ghhoV`C$XfAqp*68+ zuRK}g>IptfOU47mdb(#!8RrdoEtD_(1)y`7-^26+>D4{hf5LYFsl**A%-@Eo_N*|R zaVM^7@QE)%A}Zz0=Pf?j(3%O#FO;xzg-H5c{9Rx)DtnQ8;Ee63;Owlzu~{uZNWid> z?_#rO0U#xm)##j?%~wUBAvqbBI&42b9yn)weG;xTPFpYXZQ`jw4A=wuu#wWQ1qrA0 z`kMBXg>;ToF}@r zdWt5yz9>}NA3E2qkOY6c$HE^ZCrBl{ZxU@C8d;k zSmY4tL-;|_uIrrjmtct`o{i!ZZ$kz0S6ME!5Kozb6&u)~!g$V9TPHYMVCC*#NT96& zi!!l^)~DS_2=jDY9#wk4h9ky;eQZ|;P`4UU9VI*gYj&g&7csa-;v%jfdQol>yZ))j z5b2gqbWo@r%5sOpEA~~2Qo(C?)0!+q-~w+6{GHzP@-5mPY|-`v;R3$Q-BI;iwe#~W z#YjD*#lN6C<+9W4VbDZBsKeF=jR7NN)Q4L`CRdJj3yn)jNhu%&fkQAwpv=EJSfc5n z+#|BcQv0vl3Xe?i0?tUh3?(FS%ft)WlE*7oczq`ig0n&kZd1FQZd+|!q^`{!x6=Ga zp6*6~rNV$#SKHnSJsh7$!(h-fFjn+&{0}>|0F;koxa0L+f&nRhBh%{5sC1cDR|l?r znpH5A%&fiA_D5TI{j zLnWJdHMvO8UT$UD&?i7X2;5|_lm?ZhWb5^{oior)Vk`T}IM@Hp?n1p}_FIy;*OG81 z+Qd@d+Aepw!UKz)^wK(B9qs%%9uH{e6D-XT^sr}!(R&B&biCoOEbXZAbkmOf0*BXw zKsIuCDgf3+Af3Z}xdg%q9b|VM1VY@(1TuoVorX(fvh)tACQ702==FygN_8FNR{S4a zNI%5CFO@Rf8o|s^u>j@9G2zC+R!F&V;8JSBrPQROCj}lx`UspgeW5$29}c8d0*t z3kOu_erm*(1FG73vqt9|1@5E1Bl0&f(hjH4c~*a&pxc8a-$tAX(DDua1rbWA&^M%f zJ@TC_^>Fpbi@c$Bd%wfdQv`G)$)^dPA$U8%I|zP@fZmFhA4ol;+!uH9meR}bVd=dD zJs5KbNv-7zO2Zcv_=D<$A(cezd8M!|4A)lZVVLrZXc7?IO7i$Vj?+Lh4YqXWYqi%~X?nK6!5NLT z_yed2oo)?8ClISZ1cK3t7`VwoYyu8lK&xSekqBz+FuekpL?h8ORcwOo%|%8E6ZOA& z36nUC=(Mcaw$-X)_uEwD9ubD^ON03|j#@|H2pqv2XQF@NJ@s&aj_^IhjK@{>z)G?I z=-I%JjJXXR$11*YfYm(!h$ubAvf2n|Svp6otqhWEA{?lsU(k=KR|{evzs9Bwf(3#e z`sa;Y&J=EHDGM%~%?w%$cP;A`sh>1+kljRj$CSjqX4F3zL=_*IjY9TY!mIFV>p}xQ*o8OyG3)10d24%DzBOfOUFa zhszr=bX-kOn}J?yVvC5FIFgOj)eNpds?kL{#1r*zM3l39E7%d)o^SbSgZ*Lj5J|W* zq+OV7HSBKE5yCD+$q+9ykk{Lnj*a53=s%S3?$QsG9h|FUINd`4)DDbLcEmOjpJc`Y z)95fW&>PAYuv8uKOh^YE8X7TBY}HL-8A3SH5jQ-Gep@W{B+BwP2*_N(D)o1U2(pv> z62TJ2Ve=~KKrs$jCI#h(Fv!SKLT&9INAE-kI$w708Vw^vWbcgbn@H<{ljHzvW{}fk z77#$PsbT>}Ts6S9+V36&cC@6g%o*W#V2lR$is@ z?Z%Nly z<4BXmJO1@e7oV;E*FHs4#rF-L(2m8w`+V`5&ph|#UH*@sFWz*s{pqdQ7xxA3`Efpi z6UP%@%?t$2urH&w)U;F98kOWy`%c+q5Anvs1V;(TA-(Lfw?>4Z4;b5+$F;s|(%Ujh zNip=ogVV`HS2BtY%|H@?`wUpeOrir1UNk#{bIrKY&dBXqWQIVze&tlUD`ea#*$@eb zm!(d7__s(hJOYM4fw6R?)*L_;ChuNiA0<$;swPw^*T;y<5y!vByJgpwT!4&>s=l==yZCr>Qin1vT&4REJzCeWqX)OSUlak0lKYE z&pZ#($|)~1u5mA+>d2Q>S9q7z6`o~vRbJ&F+!s2O>GcnIi2Gc>zmOA=<4gHQ{Qmh) z6m4?fDj}a=87O|D5?IudGR}_&wuD*^RBGj#bXN?se2{C{+Wk=62EY1-?rD#hO%Y>( zf~$3`LV?z?;<$f80a^@ptR$8>^ z;Pa#o-*abbUn>zvd6}cG$m69-aeBFN4(!?G*dVeQX3D=p8|J0N%%`z-@{b8VLqG|$ zRB8VaaG6!|Pk8Hjf{ze_8z+TW?`EXx zUk{bCgv^)ts6N--b=&(*j(sot55-@6C@G06;Osgg@wW}=P9K#(2n~t<*cV<%BJrPL zwa&zU+qYX^0@Z(de7fWYA+;!smJU*&)41W5nZjM3hZigSz$Ey}sBmUHh-9=nN+V%@ 
zw#H8f0+SZ0I0XvNYBBtxqMN9r?$q}xBa_JA0Wlw*LW0$tvzI ziS6mqRQeDM$JTQU>6yMH7dIrWVY1JR- z9yh0Fe!$vH&2uFo5!_5a0B!z_p%LIfn(xKO5 zV{+P*o-vzv?0dF%Kj`2|la|rP%CR-y%4awHbg0dm{O<`X&5LZl6 zd;^yZ#4?;C%#VvVY&t2?Cawf4wmB-5n-C5Z1>;W{OsVF5C3HU!Ny={eq})80ZhzKr zxH>!M1*Z4*^kiZB?Dn)QPHs;>x^sK_=EtH=_!y!}{19pR*G9pIOOOJ~J~~Rf$f~V4 z?bx6)#xDOHS)0sm#zKeY8JChIq>@*>489=2mj#X=JCiTZIZCvG3O~%|N_f7DrGo^@ zsrejBH2~j4{YRY}a?I#+kiA%=2l!P)vRFBM-P5MS>lcma6jY6oXV^>FF(PuHB~B?)dzNJb^SO#0mb(v?F@l-c0iEs^>KDX z`J*g8NlQ>UUzLfcEIVblt-96o*$Uqx5>DIo|5UrA-{5%Rv&zooN2z|Q>YU95>nau5{uy7bt*uq*F5rCo&fp82 zyU0rJG=LQkj8RabY3aA>LmGaBKQ~P4MLWh+{gi%CA2jW_y=G|R(59gcLtBQf#_uMS jH=(={# zzx&HD`wJJd&4R^!=%NpV8H~h+SB%6XdTn~99xcz(qdhEnWm4>^3>~iu^EqZQ`zi5U zR$}F+hF4<^R$&fs>a5CK;LNcatHan}WXEXEzlYn5wz89MH}0Cam^9P}`y=S$XCTrz zAdE2cIcPI*_*H4VF`4z25Fr`bg=nle!sq)|`;QQj8!5e_gKi@13!({(B3zGdM8G`6;%EG11^|hremOO|< zwv@>z7EAhGoVG^Cs(PjhdWJwr23a63aYzHZ$rpgrCCHyN8%FH8Cj9f8J z$=H5Nh?I~^JcnxvA7QP_=o zQOIsk*$)#WbQlN;QG^@mQ545?4=y|g-cXC1xqu!9JPP*W5ad2y-Pv8=_MdER?{+#Z zy!0GzdS#USpR|cC0cBw?fLJM1>2z*?duMqWbe}4fiI7cfT!py1tJOn`Cjm*X13t~W z^t&QHY1*UOYg;qGwayM{d#{6|-=OJ8MrjhnaQ5*BFj{o2&_{q6+J=-=9QNuvn?Gzn z_MdETJ=j>AuG+h+JCAm}M)9w`wYI*wv9o(}S#;Ax@-(KRpB~YEDr6E2LwZu9Iqio5 zqc@XVZB?6r(<^Di&v>hAYun2GX={6P&EMJZ)|I=t^A%No1-j<3w56&Z2|qXtq8MFTHDHB#s|Y!!94sU{Q8jROl=7JQ0-_O$L2YIu{~Sar_rw2l zK&Ux9u5*5r^iniM;5!+R-Cffr_Swj90{5d#j{sn%M=%6F41(|sa`|sYPnY4R4rF}D zTr1LwpC<8fYb8#*K`idICYzh4s-hg3 zrm;{ZN YyoXu}j7z-=q2IpM}nCJZcGzFdED^8m7-0matO_k65FCciS)A zGw@#Yi#Vs0T)yKfu@y(sZHqguw<|(__?Ok}Ebet(rFhihDqX>w=@M11RylsF-R^jZ zheG%~j(E`PcSTQxF^~5n?)C$F)1DKHkM&SpZ#7vVh3hLGcLZ+* zUBO+&UEVy11&F{@?}8op=}+{2_>2w!r9`M_f-U)v-Wi@XJjzDn8(YQ$w)wz#U?#Q_ z-Bwbjqx){$kwwhw=3g%@<5KN|fg84$W1(Wli+r*C#w+W~gDynV75?%d4!UZYzID8A z02b8y2Z@EN&(Tqu26K4rJj?4?^5036q>3p)rSY4lk2E+bV-wBP$EFT@;^M>fx=qNx z>^Py@6ONNq9j6!h1CZ_{vySu8!0l!;CCBk2&v7IvnVg|!mYOkZ(l;VP5<7`ZMZGe! 
zOD3yYCe!0JmQ92 zdYSt;i>R_>it~ zXpih2c5IHU6UL0;tRA%v3-24@leQ7tyH&Z)U``D0zIntxVZSg!^8;hVMut-unP`in z!cOV9JTi6)c#0=9Rv8tJt0xBDX2|+#)}|zr9RBd_u7F44L8xLE-etmSTv>!}6CPnK z*-YH^+H7Ls`iV{VPl_p960_$eGq3mhk&IuFGLp%qGvZCEDfdZIjZ5QQ*R=Q=Gp~yH zb)q8lL{iG^xl!&P>$_`WN^d|nU)wk z1lM=WV~ganW1k>)jm%?`hke2ht&s&8n5hhEMQ}5gDbbH(BLp=O%O~h<4yFJLiSMNM zw$po8Fr}9#l|8pR5X6=<*VS%g8SAuTNN>7;dk(LjyWJywA@MmjK7nd(8QLE{V9Axu zD2ANVpd)-OYd1<)5^qMRE^+4@rsxW-|9e1C1Fvxw3 zs&l^XoQjaFw2F|M%~d33BdE!d`Pn?w#5^`nj6>^v<3;1=#ybW~IaQ9tl+6d0{5HNz z>L6%5;ysNM}((n@S@NgNJ;3|^*ll7*~5Uk$^D3jQOxym1a5QoPVq-?I-!hs<95*e zujHMgjJQM5lg@r?OM@8%saW(hdQQ*j39sXZf$GuC08SRNvn@PamO_07rkNMvJll3iUNvU&kBn`kTPX|5eP54DO; zcFuw|U&_|5YJjHm7=h>o5Y;6sfo^4Los=@oyu0C%t>!-p)Z`$3hDZGsO@@Dr!%~KS z81T#MXcn<~9eEgGr0|-aqqW@_@%f0g^>|zy z0Z*X?ZaQMSrur%_`Gs`!bL5x#4Bks$V#OycAUsu$D}=i6jk#`oxb!RYuzXk*!#BiNSzOuUdU z+rp!+qG^Mo4`&Zq8fgAteq?OHTU zW4JVZcn*Z>w_sMV$d~bwe>Pmp(4BTvly;7FNOht?*GJggOU*&St?(|AySQ_cBPD=E zcp}&XupSn4x|y(g%}UC-vr25T;-oS*(Ao@!z!ZrBf{|1MrQs_KB8!$pR8r2*Pb{33 z%&#ioQSwbGU?HTHSi_(%zl*n8HHB6LR~j=@H<1MGm z>chx{Zkk_Oq?^z*^lt!IjI14tY~WM&*)_^8j_naUF8GBL1|*8SVhlS9h;YT>KL9?t zhLtD;cwQv^AZ$~eA+HeVgs3Iqj*u&*mQ9JPkLm@b;11-UrF`UZR!;>%8p%`rJP=0B zot5pkDU*Vmc}+;T`SB9?Gdh-JO~GyAa%iK;d3oaat^yg9F~a%AAbSv~4pHw1lr~V1 z&^t{W5p6S^5;q@SCm%|IBM8&Q=_#tdfpof{qg+yOl@|oaLI|5lQY85!G)bW;+Ciw% zYEl#-)hqH;g`_vme{d7|V+=H`_~Y=Q08V0sBHUz@%It9{%skrRrDn(~yicDj6iJYDhMpR5wunS<|}b{u$nJmA6HxqYkN2rX{4n z89<;>MdF=kMKdk4WommD8kb5P=z(fuZd!_fUQvmQK6Sb33omE|(6{Vku1z9{l@5+N zu#yv$A|A+)2f!z~uoi@_Jka~qY0yKIh9ixzt___<;Etg1GT$`aMgtq1z8|1a;qo?I zWys0qw6o05*JYkZ6kGCjS&@Zqd8W+Bjbg?bK^@}dq#@(HjR(i4Wos;oZ_-m$m$})@|{#0 z`0)AM7pOs>${ND2^R*Oh0zr=l$q%O7LMmTzvTLcN`W_|cX|#C}1SU3x#-x-j&@P~6 zrEzBi^OACxvj#=c<364&m8-I+P^TDPndp2y56ap&PlXL)5e4!!`hE!wUD^kC`pzWU$w6E~!6wI*Tw9MiB z6tNe1EFx2rNY0K(Jcjg$6kTwrnVL~)3G`uAhCkAggocdMyqWkerZosKh+J50=ECf9 z5mUNupdFSZn!Zl`VHNA{;ZanuK>cCPM6@^mT`X2h)v~17%~~a~+@>dq#I&qSa!Jq9 pr;^`Him4^&Po@~FSxnqb%D2*z@yoQ2uIwyWmmgbrW?^ArMrQJ0#ay zE}5BK3lhke+C_jCdG0Iu81x_Xr_57J?~xf}b>_#b#_Ew0AFAD!>A&zxy>8dzVoIkz_d1`TiGPP}dSm{cdh+?&#AJlxlv zI1>kB4!;%n<#8>pf9aO`lgh-Odyg6Iho-f#zJIdXKl8C}?ajS4Usr4|&aiRSejFg8 zjq~E8vyVrrI}lH$(g>&+^+fmSe!G{Xxm5cP>UDwkp6crZ+oMr3?2Da`Wu`}}c_4O< zTcX<%gO)gLu2fE>%4B*)>!c-4N|lSR(Jm-+7f2E_RV90=?B+^3?d< zQTB8{HolT$8RgBIad8#nmtG7%G5px@uHgg2PmMDk;5+?_sh8pN-}iQhGV7leQP$th zrOv}H&bj-`4<76mX|AGF#=AwHqS74L|A zprjs*(zqpLs^RF4HG8*h-H9Rgl~{hWEq;aLyj_~QweWgd?Cil2nQy}DVf`B^a8wM( zXChK6Is*^Z?$#1N`}F;mxc~6i_e3upMfu(RmiX0P8^PUG7j(oD^EDcKo6w8KItN`!FiBPNGa4BdTi_ethTB#40voC&}_=83#qk zvdmJ1sf({6Om%Varf*yngid2=o^ocKY;0Ua-T0Bt&&JYt-830DYw8B=^~<|~X!gXW zwsLDcI-CiRp(HM%6m&IuwR|u za2kAzt2>zQ(&Bns5=ZY~Ro~K^P`k zl82#cVyM0k(YfB2`7(heY)mDmx^FfXIdyYVPT%2br;@kmQkFEagkA?3LY z%?nEQ6OET4JT7`YshS?Du1tH@6;6@#hmk$gSKQRXu$xAxg5hN%)m?aa51-zGU`QDy z(TBXr&#y1z-Co_`RVXO|@y*q5fYqET6nny-IjC{`N{!>|HEzzR!cF{zc9qibf7CYE zu(i!QQm@gG=ozwPY5KNa+dHc_Y4(Q@refKneuO?qaX|CL3`Gc3w@BP3(IW8%i8hH{ z67ntB-G zr7cqGH^T69kg;5`DN^skk=a~iuq)GanMHfNx3H>+m=z zNHtsnDM+R4vQIh0Nv(1tf6|;>$stAMoJ%gr@9Uo3l?0^XWVfcLyXWh#AK&kL^`~QF zWdpz8{_Y>GxpxfXuhiK8v(dPQC*DQD4Q>WTRlm)usoGZ6R&S^3sCTZK!`lk-?LxI+ z8vHHp@Z3|gTI9B8Z7_ylTW{3t*6)ToYg z^S)7e`z7W#mMdmjYP%1-T2uIZwN9(+|ETWZp?7=<1vk0%)cB1djl_`V6H81zGoG2x zEX;3n=V>u9mW(Ih{T$z)o6oH$_I-orFjwBaXJW2yyR*YwA>17%BPP8tj>Js4%?OgzS>$!)aGB}UKU&Yt<$kao|3EbYzQo`qJvSHQO-6GT- znoW!KEx8y4kx0iD@7=w(QoDX@Wp?G(^)x3VQE#Q?hKSm=kPwrWq;SKy6UAVq=BwDL z4Tz+LOyVMs`Cb@e60d*GoyV~?ce)VxdFjQnR*$%M{_c&X^Ehzf1|C1(m3|PP2ZtKf zW@w6?sg_V~5*bigrhIKD>hz-us!La!7Otj)`x*^vRO;wh$sFN4o4b2Gcoqe_>!~;axfp~q@zuWHXKjr 
zTaMc@pTLa#-uQ#{#Qc@nA17`nMiYQ+8jk_lCF3h|?c(EHlIs-+%Dv*A0%@v-Y6iQPx3qumH3MsOLB5t9_dYh<5A+cXFf5>tbSql zj)Ag}6nbw#ElT)iqh*S3#uX#8YiVr=yg*Z0A-r0DPVxrXC7rO_ncKcU0|r<00CfZy1p(R(*r@u|du z5*HGucRDHTnWpjZ6L}_aurs}eKPC>H2~HGXtc<&>(CWn~-13Bkp@br4HY{d-2s5ZG zE>l{m?KrHjHs1{iLt78rcE=B!WMe;X`{6WeyN}kU*_Io0y_jwLLBKXV7Dl4&2L3M0 zAQ65z%`U9<`5bTqTG`lCW=A8zl%dtlHCXtCdfn?trp(Mfb4=FpLWNyiVpxR#!G=Bzge*xufJfBrAA27kQ6*jljx8TD;ZyqpLVY9Cppm`774XrKP4>_gK z3$UK38*n-qKV-0hEE3#>}1krF*)Ri@|w`@wwi|RNsbd&F- zRy;#4X?;EH%yc2%OY7^?OqtpS1e6iRgH_Nc)*;DY7YcN2Dw8i;ZFwrnAb~3RF+XN-YI+y1 zcn30-F+F5!wQo>{4rb?P2GrEzmIfLvC@K*26>+8GKJw$;AK#-8C+>DS5sb>krD-~H%!O0c)6B5z0HpBh zp1%c?pisqqIPutz!g=;LFf)vAqidD3>L>(5Q=FpWG>X)AgCH%f0?#*uKs@0Pw9`WT zutQ?=B_B!WCOf{9S3PTQ5O)QKMo)MHquG^Z3tP%t1y zs2C;UEM9@0&O*F`{9I~Jw2&!uTI}m+ zTF4-t=5=707BXp9b72kZO!E+4E$kc?d0De0(VoMM@l_PY@pI-et8C^Gwv^4&=4q>l z@a263GLAWco+4@!=7g13t+F+4j#{JWA4PxB6bs;RAhaZD=I|&_-a-Z8AVN#{Tob<5 zN~|x9p3NXj2) zmNrM=O+f|t=iHwHa;q(9lHZ2rQ#|ZSpn`NO4*X_I0$L5kl@#T}nr4VnNLPw_N5b1| z>nJ;eU<@lpj!mDOhFhk~7p7T#nzg3ccIChDM=&1Chj_n+_0$2TWJKxrolN@BJg$na z_y3tiHTgh~mAbzC7i3x)6+ggB{E&(dP^5*9C)!<{Y=yeW2#JrWn5JTmih)jui)dCH zF;Bf@d&G}WqC(q4q?8j&!tFaR&0VqL{Lq@ZNP^yR#r9EV{IvfPtY`2G#S@B#B zzmml|l!ma49V8O3MMOOcu^S23Oe8?W9m+}d0A(}P_=X=K{DI9=&XL7Zcy4=N;@Sw@ zCYwSx7~Di1;D+7x8#^kNBF93kLYt3i$Q?qv2`^P zeto@(tUE@|;;k#Ex!w(RPR!Sbfcj3U@H7Pv1W2e01$Ib>A{E*lSkExP8>09oNC9I- z_ev_AdD?k)ci$d$u&EbXvBgbDy;>UVs9G9qMvxV(l*LDQ=^S#S+h}+~Lze)Xj^6pn z%8iAkno2vi{81#I)ym)(I`PPl4wCK0FM_Ht@>MA3vW)=HcUS{P{5Ii2VMbtxM*rA2GR_HQq{T9N52RBEP``w)iH-CL z8GX-EB~0iZ{y@92iC~`6sxK_lkhxwyLE42hz@~&Kz9tLYrnHJvz@Kw>fcVdZFA;kx z?hhkoAQfRL zrWB;CjVORDDb87^m8b=-Gaf-_LWDUUq852jtG4&uBtf=xT9VN`h#7n+Rj~6WIMY=W zw$VkR632}$#fCJA-uNnlq@tRx7CqQKdDF#epw*!|NYZ^JlXzws!GjgWsY7R$7Tk?k zSFptp2b31+0uPN?ndG#PolaV~b5F%YmArysS{iIO&FcfNmWQX67Qv#X7+TSTLzoqx z0&~xXSao2h$5lw7zSJp>>i?)tK?B8tG3n}agbPTrN9q1(|J`8&DCPF>-VnEj_>UOB zw)$`GH0+ZYrEd{;|Gp@q4%<-iH9*nmAJG4Fhp2835BPub{0ITy{K=Hkua5a@dOkJ< ztO1hymH;FP<{_|Ti1KJTEstUnWni~e^-Zu)d2FUiYno@LZ_qQxprY+p@&>X4bT1T> zp;Na?XCIrmQ#yQZ2A!L}W2%@lN|=MhVe5sNj_li+7ByWjiIXV)OzcU1>R#!+eT4h6 zhrun>%(8eFv|xS)FR@02N=bf6HN}+rzQGf#i}7;NbP$~$vlMN`T3y8>-otnBAhaN_ zSYE9h(c#61=+MyynRM|96%>n9kE$qc3s+iEIHNA~R4lJ9rL?tQAX*9|lr>c_-M2io nMpe)n)8dtO#Jhoa7531G>H3X=X8j+v%V&;VE*~u)FCYImJUU2v literal 0 HcmV?d00001 diff --git a/ultraled/utils/__pycache__/logger.cpython-38.pyc b/ultraled/utils/__pycache__/logger.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39df2674b3eaa7a66532c8ff58442909db1dc557 GIT binary patch literal 6198 zcmb7I&2QYs73Yv#E_b!N`iSGNe3VFO^P(Jow#+B)OH$R9fPKJxDq8U zDGwe~`{xbg-&C1?EL7gXoBav}H#m!pw*Im7Kz$cg=PgHHVi`E2HLyCGPUdJ>xEGS3pE(uZQO`@p?LUDu<4jnD-ccZ zdw8?+C<;~>KJsnS%XrAoUn2Bjl|L6cx?t_igzM>_ja!qa_O>+A{ZH3#F5O^nW7v@8fyki_I}78kC}ZgdhG^jP=(P*IFxfcR4U&xuc?dmqQXXEo zdgH1J+l>2^-_CbZgobj)}I8s1sG_FIEj~=2) z(@F^^K;*{2I5KR5lz`Wy?LJ^@O-tH(XG(az5i5^6516Q9ni5|Z^He1{h+U>M&iak9 z7;664%dmw>1=iofLA;6c@r|4f*b&KN7RDC6)@iEYKrpI=x~;|$MX&K3^*Ti1EpOZ6 zGmcp}SSm|nZk9SoYfX@Z`ZQfZH%ly`kXD{{1noe$R8**#qk@i0JWB2*`#?T9GJqhzC1 zt$KKVDmJoJMNuu?$-QMBb@Jw__gs@@uNrx8O*rWVnh&ndb5t_bg((;Sqv!{HD)6njBD%L~vGCDZZ_ zS!tAuFa@kIwW(OCR*zb_=;!PHNV4@bjUoOgS?J|qlBHrj4Fn$wKAEWjhuG;g zR=w+lpK5`YcQNb+z{XYYwaJXBHPZF1)!XCxZ7g@=$)!e2{`2?0cx!1AYEB9m);cyU zWb_a{b{W4!rGTS2LwG}!dB)(DsPr!jIdq!(GA(QLCS-SG1rb-$!LI#F4jC{k8>!RNj@A@MS{ zAuD4QDfAQ1P~$vZZjq+>@WjNo(Au;W-b)J>EC^}DG8K~>sV*bk+*Q085r%?62T<4m z3|3kHa86r{m$>sA7{ynJp$e0`U)j7;0AcW* zyI(#m74}8r)`>%>C>0K`p|yVBeA75A-T9y>@diW7`q`Q{3mly{xDOP;PDw zEHa-h`?PUbDZm|c>9oNYPr(fyR*Nd?P63C>7f-_{BE#~erzVrpV%MNQ; 
z?U|wmDNo$DXjRmoEz0~PBsx{N>f5+qQax-icVs~7-+pEu)(@S$RyaB0!#^-_9=oW6 z$LYdBdFDQxGIsJfIXP+PXTfpU1G(zrF@FwUwp|vFi~69!pFd*Yl%s8au)tp+e)r7+ z@*i2uxbs-I7xkHz#lcchI%0^+O2Fn9gU@g{ms8OTas?!Wo;{7$LmVb-oW){l!9DZ$ z0+EeD?wf!yF$j0k&gKZ9E~!eg+gnGBHiODEbQ;Ru(@~=avnNKzBv6ca8Mt*7AC1&u zAQIY_2oNGp?TUEQLx%1xcjGjmc~=NHsp&~Uc#sDwlADaah$hW|Tv!-jPbBVSt}LmY zN}HOc1qmOqn`;lr%l50S)~%bWd$n!8&imyRHF>pf{cwf%ofR(@tIK_>^?LY-t@N$t zElw04}82chV-M;B9_sv!B5u^GWp)B!m zeRrd6#(uvXr&$(e7n;vVYbV_i--d|NCRaLz8VbSAQI7>@DqS!SX&E3QN*=~}AT23qu>FaWn3Sa+#wHB*y7{DhKDG_f0zS{yG04OTe%Pa!;GX+uq}Ild=(?A7{TBL z;AOyPj=~Rkp&GA$;XqgBpv)U;?tEb>4jN5Ka|`{}r=gdm`yOTJ9)b*SKS=nxm#1EK zkaUiRZ7r2Xt!c9_siCsGBhsy~lgkPm>@Ey9;_VWQG2QP!e@(RvdFn{Vql@H4{B6lv zvXq2-xE3!$KR23=G$T$)7I9H+(g8*wau!!HpuIp6`BNv?KSPnF7Wf$Re)Y!)&y`$n z!)$+!`GT?+X(8m1<$VpX(k5$Y&uxTBNLUNe=`u&Tcn4#fWpSS7smT->DKW%{pxUll z730M(sAFR4WaV;E8H*mG1^YFcGfHm^%SE2JfLmy*-wCsdJ9yR1aF6kJN2Am#u3Dv> z!^Gdhx-c>~_%i~s_rqM}JQQw@GXUieZy*yIA731M>!C;bi-I`nhun+0@G;(IkikxK zbcV@vcguVCz~eea@B+o}dM?7C_W|rUhGX21G)sVX$dQ6|z1vvK??!QW+lw+U+X*{S zHwyWR7kHP^Iz3VG3*sy#p)`l5<$79r6^@TmwyACE)NPc4ta|5$hkS10b4L=X*vIQmn*9klmTg>d5ZZoBuQ!>pi9pe_pqtJ#Hof_$>&K5+QYca z{BXfKh`hYCJF~V3l5OvMq}!m|1u&3m_cyTxFU2fMo7J>1(`SSDT}h{f$kQ2KFI^HM z6)$->Oge&$6p?$q6p>QeRny2<4nw4CVZJMpsXUs_h4S;ODvj6roZ6aJLjToQmV>NA ze43f}%CeG!wB5{BQ5`z#W_DR&pDbOvcI~5UYm)8Xz#j$VUuX$g0j}s}8`j$wW4ox!rKSSe3&Bt&}A?TM7C8MCqSgM*X zu**h-u>c=PiJyx|!bO}>Q!P|ib2>D9^iXm(&im;9x1nd$U-)-Up2_NeOo$OCdyz;J z{AY^40}y4=2BI_>9{4b}jt^F}QDW=LCbG1f1L8vu&LvaBJv`cXJus19QzRhg)MpRz zUYdw)2y$$@QvLS@3Dket|G&98y6R|ne*K%j4{l!3w7*K?>|=tshc~&4hH8`ywO-kio~~%4XHcCQpJ_dl zn$$vXF=JYpX|zVG)c%YdkzSROkF~CI3>n(PfXHtCNck4t_wXhWnpB$-q0LB61fgVI z%XRcRdV?CmXlv9IdRyyg)Y{N6V?8BPaLfp1D$Dr|jaCK5z(?kMrj+*-V+;Nb4LaE` zZU))oFmU&?C`iLNat^u&_T!6EBQt7_j8}u?|I#4K#hcc>ej$ z{KLF=EUlYUfBECFw6ALvRvPnQP#FH>q^PEcW9Iou*DT0EVPfc@XUX#W6CK}Y-g!fU4R=zM^lH`tXrjD|>Mh#eJ zK(qR}k@|dP^|h7H=PRY1udiH=)FEqS23^1_dk1}ermZM{&euS#&NKm3^QLgb8vGrI zNezUPHpJQyLEAhc`=<19{}zt%y@flwb9Z+Ah`Mu^-dVbRIN~e`U$D~~aourTJokA& zxvjn=6-M3=&Y2F}YhjdjZ@V8x<1BT{vFoR7ah9Ir`*D`Oi7~-#WA3&6VeCt3A1iX- z4{&dm(w;INv7t8!>2lv6sxeoMEjj%t4gF!5_+iv{uRV_AA%y;Xz|sNZZW_D(csx7= z0BQIgu=$B(n+#aw(s&Xj{%AbJNGXoXE4;YS?{Inbxw_+DeD#OvKVJR*ZWr;5L#;F( z9-rXpu5o#P3L7$#u^+I)f>|*yOx)8!VFoN57UuJiO^OOT7~>SwB`l028x_?s35IdP zib~{j9#0AG)z)I3K$IY0v4r=7uC7z^i~!j5jyvp3*rUiv6qgx>n8s-ksHT|Rls!m$EZL~?#Y?4ce zX{~uHBV2>j@#Ejd>b*e5=*4ptA2wbU*#3kWN^&E`B|0KcNUA?Ep6dK(RL@PJJweC< zbw|hxZ+?z2P`3NVrj{enVV#A5p-rpQkO5(4%*fQnHPxamNL@<;=84 zMsDR5gvXU>9ibBeYDUzzN?9{lncF#%XYNotUrATNVz^~goSktBtHZPQAK!ai}rL4NwOiJ@3bkCEP?e4+eBkwr?T=QWLyvz z5zs0K*A*nv8kGUL1uHLMd>xW5AvkW7QmG{xO6ta0sbUQQ>lD9lB3EiFZCJ{l+h!}9 zIx6Hs!`w;R!agDov`>GPZ>AT;s@O!(wKj3O`DM&mYVKZsMYP2gaTz~nLpy6R&QhMF zKD~q>inTK$I^f(a@!2=-XfMgLzv2D!QUlR~ZtEK|_-~0VplgVh)C=f#6b_vQx+Myt z^|^iy1$fRI(KY7iDycxXwM0O?CRMDL7=IJFQd?=kQue8BIO}{{G|Tf>=++d1zfkC2 zcq6*EfbN#KQrf)4SZW4Tn{!mRPdPRfs>ZXQyfmJDUl=b9{#&s<-B9;sTbQ7~1l@th zZOhq2QdL@!cTQ=yEsIM35km5RBrO*~oVkAD{^26$yL&tX9!i6svNYq-nG$7z z?!4moy6^(9-Pq-E7SR&v5B5~jf2gv)ss_0H$D)!pZ{VxDS(vh>FC5Ps#WWi-&nxt3 zT-cr$!3{x~NUEcUOglP2~|R2)GeTU0V=Cxs;ok3rc@xJzo-k|3!DuY3C=9?|KCrj?hl4Ej)d%}8EsO$MzzpO6Cev1EAC|sUfNtV__kYaa)+^wxppdDMhRz>g z?=IQv+NC4h0?Dt*MwXbR3o5^;MzPoDK2;0O*dzZ)a#e2GcS_&9D~s1e4#{fCHf+a! 
Hze4^E$#hHK literal 0 HcmV?d00001 diff --git a/ultraled/utils/__pycache__/misc.cpython-38.pyc b/ultraled/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70f1ca00b4e1daec5dcbc4ddb1dd9c685335a9c7 GIT binary patch literal 4357 zcmbVP-*4N<5#}x_iIU|XadSz0?FGD`ZmcvXUwdsHj5N(pUk**&L+l<9q8EbEQZ_A$ zly|8n7AO=*3?xBu!1bkl2m)9y2jpMkYoG29C~&x!_M4?-%RztwrNHIxXt_H(Gv7DE zKh4h83|xQv=f8KFi-z$p8cbg%2G?<`12n=AEHYfCpWNm8X}YFQ>P&iSmKr)Yfx!#=*PHhIMBmGz;T}K6hP|b}-B`+6}+#*rV<6d4$ z6WQ7!z2d$fd0K9@`ohoWr$n9~MGYp)(Ce)s>v;UQy4nrm?ft$VZ?C3-O1)Mhg4KI> zZm;&ERQgdMR{Lofsnu?%TFu^mKCgn*EB4ph%+N`-TX~Ju8QpRYvKHRfD2~Qp5}{_C@Dj-_+*+Nx=uEsMuVje2v+&h(!y?2vs`A5mZJ z`O@!tonU`U$MQy_oLjA=x35i=o4o{}57VHl zZT_6A^Mst|>=|73ioY;kuot{-5{3@!SH?>?Gh-PCZcLcB?3WQin=L;^{6n3f1v1-r z5Tor7&`5Uxniy;gxC8>SC{XW(2VV82A)+ymwDkf<%NAfEG^u{O9X@pkPfq5f{azGk zY+ZI>xiE?Sh+u}@CNt9=AJ_#*WpD&ZT4+kSFm5`xgSHQMuQ?kU$25a39IG!C;OVS9 zPLk-vE}#mqDxG*s_5-4d@!V+NQDhg`;KY{4$2s0^-H#M_qu9-H+WjaRo8y1(7sjoq zFM^4s^lQf}(TYch8tngG*_j zjlm}u3ddY>B-qf#5bba@Ksy^Kx;jyn5u@MpYv>uaJk~(A*;S%xg#r zGF6l#@(Ngm;=O?I8JsAuz-nYKP@G~f%1lH@H)K9j484&AqKt0ro(B56>v_?F2_ntvdpByt<|-w+Ihu;2&}2M#Y1 z3`2UbDC3WqsDRg2ggrF(Rx(q5m2y#~M5Et$Hnb))hb07uDQaQ{fkPDmB`PNBdIekk zBO9+R>PtqaBHf~fpjl!T_gq$fZHRfQGsME1(ufx;WTmU9L&T}91o`4bwq!szi?rBt zQK?!7Q;AWZsghN2o==DN_zfF7obGV>gZN+~8Lw+&-7{}Yz19z+f?z#6cy((hXmvn9 zl!O2>6eUU*DDNh!31~HHPnId5PiI9S$ALfpsHt=eZs2>ovWLtiofR-A<~%4iBN6x@ zi3=tu@~Y0Ax}t+7Ckh1eI{|)^eElJsZCK*#MHC0@5HJK#0pY*%JqzbTfiupp-dg}3 zlYvOy{8at|-FH)C?P`RAt(A0p*S2YGVwAGMjhZoVT4WuRF=U}B2OLqnqeJczsN+bWjrR2wCl55obo*E5B-DoJ@fb1#mDXX|OU8*QA==yoN=nrb7 zMbW3i1!aiJP1NU-D1}I@7x}F~YD%+qlQ0GwQWZjcMaIIZ+jDS=9RHM zm$Ofn-em_PyJJQDeiYsCRZ>jN@MbkoTtX7hKQpGys7}uJ^H+r zRMqOwR{g=FYiwQn>KvoJuD6din+Q*_#ec0DN#l~`VXy*jW`HI%~WHn0Et$dG+LdNNEKS0R@PU=wjU>fY($-Kt4V}+e>NHp zzDxMx6h;5I;1HX&`G7$L4zXF=7;w(sXPsv)mY5r^yudN0M=gw61JU-RoAP?kYg2650KeK!+jx00Y8 z6J?>9P?qo$|7#`>fxyF8PPc>3>Q>M1te#E+Y*#Z9!Rf0@i>G^GBK)|t z*9=K3sPQJaNMsKjbp4jpf`wYq%Ibi3Eml@c5Prt&-lRw#|38`CH>Flm7KtOtJY|8% z7O%9)HkPD}q$Ed|MFKm}d=zd5QkCBJ+dv{FTSJ3 zeIfR%aU>u?_b+bmqhD&Y{cg9_S#2c#sys~V9;Jfza-Yzp{hL6H#giyl&fq0p;blH) zxQ1goW(j@iVWdemx%;CwiKHFz12o`9e8)&(Wb#5{rf|h=AvwkGuZ3wC+sHs6Rec#4_K%AINmM9D>#|m|0S7u<}NF(F!11B*j)%Xaw zcrzHS&3l&#m!c0%@V)Ln*%q%PS*_r3txJ zt`!qk%8I%{M;;@PMaHIe2EXMik7~X^cG4i6XMr7)!LS0UdB%`4Xn#w}Jz)G4_Hy{N zwttO=-~;G-hYi^x`(*QDV_*)fgr`?3^NSujD->V zDBCtN$oH5cmtoy0JvKjo6#JN3+Z?p;(M-a^?l}Fwkea+C8fjKXzb{@A2Dd&d-*`D{XW!= zOkqSmvzh}O?(VPqkbB=!@aP-!AIX|-eBz?x)oe*&RJpCVg{b0pn%HY%TRBv|0;0+! 
zi?fvpk?N)WY%8-JuPVNw_~s*4AWwtn8~EZ13g%jHI+q{g6L7rexgLGY;m8cM3GARj zF1BIB)yxcpv0=z#0uD9mo9NGw1C9DN`e#H=_qoW!O>ZV-Li)5D8Wcm(vxZFGPT;Vk zz9|Yw2i&2lM?M2aP85b*6gQasjOLH)I((jn|=s~vr+XPi}90=hA4S{`5D(UiA_rsmo0v*%@x%g4O5#cIdYpLZoIyO$ zP>tsR?>q5SYH#G^cgYyRRyf9=nqr@GV7nP?8||dL%|52NCl(mWr>Wc-7v7K$aw&4p z4d(%uXOha!A#qYn4-J9>&%^o?x^-fTAu53qR)Lfo1>pag+K6yskWceFRq@P_Bi1>C zLh9@s8M3rMrNIc2Os0ixJ}8K12QEgd;<+Km%*jD9b;a{(@sCWr@Rfh(Y}N>!KuBtz&htgJT7H4z+)w?piTT-#@f5>-L)&?z}l&unmWmU=X&3o zMPOgq_GG(XT90mfyC@)&Z>~ixgea94+1Vz4glYY;*;&7rM92i=UOS-J z7|EbPK@|Zn=QoL;etvdV1h;#u{SsECoSN=i8`)LDke5M6UL|71*;!msngK<@hx}B3 z{xbDQuwEB=f8YVG<%;LLrMrF`x0QI_L#E&A%m;11A%dpAb?&9PvuDqOW4q%c|I(B- zuks|;09RIoUn4^Hb687S?dQo2myT0_>v9gGD z<$fB;P&DFJKcG8drw5tn@~DrDR^!d|vb$?&W5=M<0BmANMp4pO zP4;!dJ6Sqd%C@M1(<=#71=lRSYNHb*5N1OagS)upwQ!>X@yfnVgJ`N8X)TP^xSmBx zye1SYGmKW%*k&t+4qAPmHe8vI^IHefD+jH{3XD(XfVWD=r_RALST*)Hg9ePSxzUYq zVrEdBrZJMISZS8g&>YmJm4gk%HbCSuGB(&-C%HjgTv+Jdfl&>UFJWbQk(PHgI*r8g zIfBxb=v-JnPc09Ha(2xe3Z?u|QbsE&r%@(e0#q)09Z-3yX`HUH^}?R1VgtG%&U~AU zr>jq7ocVWyru>iyeN4qUqGo3vW;J`lS0xc_!l7_mLXMDSa%L2Fl?&BlogRs0twSrd zL#m*+m97l3PK0t=ER1WxU~QRT$|4(7HuNI5l(i9T#c~mYb%!z#Y#0(k0Hsef@*_)< z8#H}|3Q5Hyx3RHIgIQ#Wo4|;Fi-I8snLt)p#=nBx#K38X|9^70npOA-6S)a5BU72+ zRpcoX=)1^Y9A3e-%{9t^JD64GRpc-w%4E!A$ZSB(tfF)@b_G}!K516@3FMwR^h&_5 zQfZboMhR;W6c9t+2UiCYAxeVw5}`k!gI5?9A`#Lx$9oFB&=_r8QSb(XOcIHXjkLf{ za%gju$Qp-Srv!L&NpsssL=p!{gM};P1%31vkmxu#)?Ls-VBRSX84h$L5GeCs+4@x` zzf%%;4d4u^!dU81>-)5aGUJ+aefM3(X-3^G3f=M)815VLRR7R3-aAXzmu_BMUh>@@NGI_Iqfu&wx&u3CX!f+hd_hY*pn?`tX0t7n zL&JInvP&e*MExO%ZcF<)0V^SUIe~qaUt+w*b!@Hi<=PP$10efU z{DKO7w=FiKFbwp86YIwa`3#LA|7Drz<=sc2IsGVKs@!}8A+>;{6C8Xd1`r7 z%uqr1Bkknc;xu#n#MADhA8(umgv?$ToTpo}PUT6sT>(kVs^LsI6Yde`n0wpc{{>e$ Bq9*_V literal 0 HcmV?d00001 diff --git a/ultraled/utils/__pycache__/process.cpython-38.pyc b/ultraled/utils/__pycache__/process.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..961fbe1fa9c8ed16b557ddaf86dfe6fdc77b6dba GIT binary patch literal 6705 zcmc&&&2JmW72lct6h+CltjLx>6ee+$ux%(?iIdbdocJ?I>!^vHASvsBSaFsVC6dd` zu53$KU23O5(Yi$ov_*@epwN?m9^3vCy|tHOkL@8)PrdXKAV`1jEk#O{oKu02;Oua3bnPoQ>0O{rrTn^4EqG@eJ*Idwvv#K@$YSEtlz z)Sgf;swY(ewJG(IdP+Tw+A(!rJ)_Q`c3eHHX7T2_ja3v~L!`56JRF~#fcdeD% zhU$%lqDZ~B+Hs&e=M?JJa@c6EJJnU?UpSkZ)y`^F(!Acs!PCX>Dt^%{io{qp?pPag z!`!eoti(t}S0pka`-x3!(%7&Q8*lbBeM@xPuNUgh#;9k*5^NTd8lIbTk?=42DSV+ayR8&$cT#chGLyzOl z&*tY^LA$=T>bL8VK8h>VPz7^0uV0&6ZN}Pf25N3KZZxC0o_vHlwqQyuZ+8BBu1BktW0HXS{rC9TiKt;;gb4l_OyW+VlXzS^+rW3 zmL;il(ol|aAIS}wNd5<`uO$8fZ7A1Ue|kr`)vyU@R70(UYTO9hg_a-d25?r=q|I%z zhG+b9;czrbfpq^4@Lie~C|0^M4TgVc^gU0q7eXqCF1OxIKF(W4AVWCf8KO9x7cG5^>UOp1w>qiW zXve8tYleQjW$I}j;gt{6&k-I109zeRJ`cm-jUKqBWI^cX(cY&GVoBBpDeYiwu(Xyw zg$epJ6*E+j8jF%Qte>F{a@Ap@P3_xMrLKKNJBjF@4LK!s5zi_?XXir*>Vy6rH2V7z z06OecGpDH!z$FeqpQY@2ub*u-7;&@Nb%L`4c#A|ZoJT=`Zxmj*RKRnAEm2~T^aTD- z6Tu&!s6p zcWnI)^i@fLJ0Fs^kIO|g60llqBz7Vm$PFWNQbMURADL*IiSa;eII)#DiNp+tXE^;C zPGTu*#nbP_HfESKqn(Q#)Gewf&TV7MUVrtSEIxoS_YowkCv(aIaJV;O)OYsXtNiCR z{S>+uiY_zQE>Mu9?2nAT^mx^81=_Cwp`EZD1(n)rn@MELWu58~%uhYcnulT|{TwyT ztG8~XUZ@*&Fmjz8VV}x-{Q6DIXL^9jWe!~6Wu?gDsb9qCSFLva0ygnwdLnN~Pd`pi z>^_>1HLQYsPX#>VZA^H)9?oNQ4-t^o34%-nze6PvcMu6S%-BGvH8ur>-DL?#0(i@o z64QX+HRGdq03{=Frj4#m(A<<8ZtOfT5JB9;-V}-1d%6$I4KML%7Qb&zQ&2~InKm}f z1Rz#+O=gC4v!?*cs}27RA;`2DEcDEUu@r)(cCo za7P2|+^VN0l7S&Z=n^LQ_(jAd1|d*5asmW!LYyHQrt!yXKf&l%rG3_}bFg-cUgp&= zV}F`-pkJk8uO2n|pQgAvtdy6~7vh)k;gC`^jaE4xph6$AOdpWHM}vWgRE41k7!U+c z6ClL0NzlvU0+Jd)3NQs+0VVr^)NjTPKnMtTy-g%T+_JJ(cLX5J(Tdy)z|x$4uN{vf zzFEE03|Z+7a0+%m%~67m 
zNQJC7b=!(x?zsYGD3$hVtFy-N&Md`tBvwpq9xA$N4oS;Opei^lq$Ud9RmQ{e-eT(Z zs9C>8+a5~MT+AolCJbljmwAyFQz*yf6rQtEFJNpR2lq|=9!7UKhf*&Xj2KA3G}IE@ z0tmRpY@`erMQ~|zY0<#|27&qyLb_V4)se^o-1A~5xhFtw^>bQvl zCzEPxAigF}Jy7giPwdLw`A3)&@v&5l^LJKgN?>kOImt(jK6}tj-jdk*IHElS!MUa< z03!Ab_3aW+?iR2Y#bqW~QLqrZ_piT;l?VGzZ9W1WD8bl{1e}xq67NNn;TXaalX6Dr z+oDImxE6&5Dh;Ir>8qQXB@1W|f&)(Wk7GV?*v&A-1 zBJ;s6-q5(~Fc{3>h+G?|s-R8W!c?4&UydT2NYK8nHI8UkXflO$eT|CiRL~hi55Uiy z-DiRwGku395hJlLJRbZ|7Rc_qPS5V_^NUb#9M9sk@lQ9I|-RbhOfVmzC8j`YDIHcKzf66 z^oLAkl*9k*7kuCrASwWm;u!%3kg<*#pFk{R`8HbiD;(?gjX8A#DHraiI31fslC?9} z?E=Wd^EZVsJ}$>r*G^n05=k0P1a8gz0B7j)#=FKX;}6DdetrnvfePJ)J;G3~MAjlc zi01pkB*rz`D6Z4GpfIvf42C9)ZIZWGIGm7|!G>7U`jRClv1A=?`%#`Ilk@&uevFil z0}XDD>CvHl+QmQi3QW;4PeotGKj&gibc_%mkrICWg) z#FSW{`SLw6bhiy80}U9>*q)}pb8!3A3EJrhy-P_(#90`y+(EY*?6f$Ny1r6$5ug>` zOhqlV`Gr)h=xLQ|yA_)@~-b zgLk!0#V$zGz7jT>&L}5z_)Zk(R_(ZG%VfLJJ=PTdOBU#(O)vCXJ~B7gbR&ee2@uck|!^WlnzujjMl zup34t7Ki1g+#t2+x<1V}qKbdlZ#4bIW{}ulR1R(!1C_`6D#i!?ZCU(3e{>exCA;hq7n+r@t`q8UU^)4j+f7QIoHeO XCUcYCxHpzNLhYQD%X^-8%p3b3U&@0@ literal 0 HcmV?d00001 diff --git a/ultraled/utils/__pycache__/registry.cpython-38.pyc b/ultraled/utils/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37dbf5b2c50d4d1115044e49d9578137f035d577 GIT binary patch literal 2811 zcma)8O>YxN7@pZ(+lxs8;TzPCfl5s*fg`n2RYepLLP}|gK}mXOF3V`MJ&aY16V_Zi#v+(2eH8~?^amtr-@&v|X z@}!){SeG8Ooz46W)s9W72jQQbvAV#swWow0uZnapOhr#8k7KEf=!OSMT)QEXokuE4 z#X;EX#ohK*ktV|Qd%Z-bxDac3trzNaV5Rkq(q=(KeUl~!;vkXzjxx=9t|<|bR$;0{ zIMKMMn(ZbQcEV_XC+SXXX)NcrnoSWUQeE5WB+>rzvgi%cy`-D(k^Q$;R=>G*zqJ-@ ztlhc0xwY|~_(Bw3E-nOV@NQ|_if)^u zmfG|$0xgYYra;s zRJT3ohu!v4s!SSCewXg8-CpWe9YN6voDe4c`6BmOJMixm%} z_Zw*Y6CXKrV#Ke5?1)^YmGNLWzfwTR21km8;rk2XhtP@$rU#l*q_6f^Y<)cKleCLF;$rc)7&aL%4=q z5?+GfVY3vn3byVjf%+l^;;UlZ0ELBAQKG{%(E|V7B%R#YR7%JsLR3+-M1k9!p1U~Q zH(^^%j<+II#PXq#Qw&a{nG5KUec;4;3asRwj(~(i-U7;j&fH=>kFyPDdv>?qje}uktxQ{IJYYk;&$D)W5J}ZZxgOWLKk-3p|0jwo=d0 zjKt*f&(h3-!Y{>O$=GjO)pzD9o_mBt#i^{e6Ph?O`UozFOC)WTWbC1%m>~1~6(=-E z28`S;)-1Q2^NjXVU}nrIVive$c~yJLIrwMqN${>>&r~6ScvlX0-(Z`*1R*PEW)2%xWj75`H-g}69594c#;2J*b>axB3E||n%^+EfYRDe(m(DBp6s*m4 zyO|fKYFvscm>!rzo+FLQ_Y`!gCgWrftYf_}<48reyZ}~AfNz#-w!)EU$a_( zdKm?Q#?L%v*{jdfj1sHqTcfW~M@pwMF|me1N@=nmSf6_reZTJ4PuGw6uJ3Bv&1y1C zLj#<7p^o;l3M^5b%sy(}X^7hKs%I4Zu`4c0XUD;Uus(js^TfcvM r?cU^l_TK%?&B^PD^|h^yyQ}5G+icGK>v=G4Si>m^a}RuAmj3<$aHp5^ literal 0 HcmV?d00001 diff --git a/ultraled/utils/__pycache__/torchinterp1d.cpython-38.pyc b/ultraled/utils/__pycache__/torchinterp1d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..262f64365f588f96166ba7a3ec044ca191f1d2f3 GIT binary patch literal 4374 zcmb_fUvC@75#POkJkb5Cl}8nlF+Mv5#$E`~mtBByoRpM^aRh0tMQWZn(EQJO5{P z_JhU6GQ;!7Pyf~3yTaH%=wtq|@bL~_bqzo;!6VizetDC>&x9e&M@*Oj+c9>`W3y=q z{(#l37risx^RsHMw9-L%2B=-ZaI6u!Q}HH?6=Hjwn51$Gqv}6wBJZ z8>fNnUl$u~8@s0ujUBwdz^iHi85^^VXGUgb)`|6ieZe>CT$ehDJn*Hc+uBk=)Y1G< z^P#p5$sqlLNwUOOZ>{%&czZbTE4(hU_U1*~rGPvCT$K z%Er)T$2j3T=BQNYCHQCN*g9c%*oPne-Wa)~h1AX#cAR4_j9sqor6pk=^N;v!*4i~> zKU)ws{e5h7j8*o;$lRMOcC&?IL~xbuIKo+F0{Bljt%ItY*)&QQ(sJgc=P*+_=G*)- zTyxAv&Rzpth%@%6FTuG)oU@$H$awC>EyUUE?{IH)gBgIa|e;R@K|{t3-V= zSJ0S9Nps*?CZ6KlsQmpsay(DDpp(y4=CrjS4dLZ;qjSY*ym-RKRRrYXi9s_rnW&D+ z>3NK2mYVJ%s~sDoN>-W8EUvPSF>$dt}9B ztS@DYCwx@R_qf<|HDx;&fO|f3LjDXZPYn2cW#ks#8ehzu6JGf5xv}@-tU6vkVNc1H z^V%7pv|r4PU%~H8`cuuo=TfmkT2F9k#6;~Alhd#D7APYcFJpYhufkfv>IyqCpR&j7 zF@J1yOtB;`JaR|p;gcGocu8=`-5BMuw8|jQ*RW+RTPpO}=067xEU+Z(VhnBc6Hbx> zUPK($o}2P@a9=rLv6HTjubi+yeGa+um@1;0O2;&|Dbgmhwpc7?S6Egmbo>E)aQZ&} 
zW%ep@ApYmY!Z8QDcno`!buo62`3ih^nZaXa$oD1eU(L>u2Y<_c#(vFy!+wWqwDW3K z?eJ{rbK@D~EUThCE)V{Ox?1)o&%H1XeCb`k;}zv3iTpH7Vh{gmH}KxS|AB|<<=r3V zC8sgl_plo(uQQ0-IkpVafsEB0jX-$&el!S_mvlVeQ$b3D>u=ndZBm|}&d3oj5vWjN z-`*gQL$9AeAmwdRk=*=Gd?CX@B6q`h+f!jLjC>gly)^ND>8Fvu_4h}A^&q=yhyCFk z+k-Ieo?*NhSV5KchE>_~MoJI>VS1_~Ng4Sn z^+bphRps6s;ByO>#C8Rnc)lt;N<v z%-E8|t*s>Own!S!k42#x+ziZy_bbRP4-j%|r0CQDh32GECzR?morv0hr}{xRY;98iuydkS74^gK|c`jy88%>$+JUZLwtJP)-~$W-T&y0tc`GJR-F9p6Q~ z>P-N&PHdIQt2o^7J7#DjxZpUb@%4bE#@H-Ygz*H2CXNGhL!QSzSp}$Dy4+6UG~6B} z1EsC~5dEp07k=$*`=k@PSu2M9zA&{J#-dpYRjY#|L-YN-g_Y+(N1Nz!6$XlY$GFQc z@uMHzowlo!Zn$xioFA!M4PZAN6(#*Z)nAjOi?%yq+)1=8(qTW)X3$sK4pWp(Z6f!z zokyLn-c~A*bOOCAWg_*W4E6>%LsV;9`a)YM(J2W-AeE9jbh$);T3>DU{6lTSH$kLJ z*hGP>?V~^@N*if{?IvKWuqna1&a@z%mX%ud0 zR}K1o&{TR!A!1t4xwYlDck^q&1#Py|!_yGUot-;ZS0)O(ZP3V|Q zDq2*jnimrpZilfS<)5uW34W(p&PS~w;aCLaD$!{GG%F;MXcbvgxa|@xlJ6v4x+u}q zGXVCg#f!fxU*VUGDqiwtnP1_{+%>NnSH88(+BcR}``U4gudK4gukqST%P9ljGOzLr zgljrR8NW3fdrj9`=2t*dad9|KQRn*vGhYjEt1iH8G#Vw=jkEh-K($(OYr!U(P4i z$P{dwi_$V1S;83EsfGK6L3a&wkgty%-0P_KXS)mX9&UIxIx#DA(4pCax=4HPUS|D; zuAuHlo=E*ba%UbORTr59)>4A94E2Vsj#tTy||E*qsI zIQ*_c?o`-MK}}KEGlYqQz^HJ88Q(Zo9w7LtAXMgVx6^%fx@~XNOLQBiBGqc?a;t?l zZh#bR=}N1$H}IokMOHvQ4|ZgBN+%Bvg>E9`yELC$`4r@dG4-9Ic?7Fj{sXHST+)2q j&aEWZY4c40?Y=kZjq~*%BBDxNH0y?)`86 literal 0 HcmV?d00001 diff --git a/ultraled/utils/color_util.py b/ultraled/utils/color_util.py new file mode 100644 index 0000000..4740d5c --- /dev/null +++ b/ultraled/utils/color_util.py @@ -0,0 +1,208 @@ +import numpy as np +import torch + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. 
+ """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. 
If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. 
+ return out_img diff --git a/ultraled/utils/common.py b/ultraled/utils/common.py new file mode 100644 index 0000000..aef0606 --- /dev/null +++ b/ultraled/utils/common.py @@ -0,0 +1,17 @@ +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/ultraled/utils/diffjpeg.py b/ultraled/utils/diffjpeg.py new file mode 100644 index 0000000..32ca9f6 --- /dev/null +++ b/ultraled/utils/diffjpeg.py @@ -0,0 +1,515 @@ +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. 
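A quick worked example of the quality-to-factor mapping above (an editorial sketch, not part of the patch itself; it assumes the module is importable as ultraled.utils.diffjpeg, the path this patch adds):

from ultraled.utils.diffjpeg import quality_to_factor

# quality < 50 follows the 5000 / quality branch, quality >= 50 the 200 - 2 * quality branch;
# the returned factor scales the y_table / c_table quantization tables defined above.
assert quality_to_factor(10) == 5.0                 # (5000 / 10) / 100  -> coarser quantization
assert quality_to_factor(50) == 1.0                 # (200 - 100) / 100  -> reference tables
assert abs(quality_to_factor(90) - 0.2) < 1e-9      # (200 - 180) / 100  -> finer quantization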
+ + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, _ = image.shape[1:3] + batch_size = image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. 
/ np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. 
+ """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + 
super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered + + +if __name__ == '__main__': + import cv2 + + from ultraled.utils import img2tensor, tensor2img + + img_gt = cv2.imread('test.png') / 255. 
+ + # -------------- cv2 -------------- # + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] + _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) + img_lq = np.float32(cv2.imdecode(encimg, 1)) + cv2.imwrite('cv2_JPEG_20.png', img_lq) + + # -------------- DiffJPEG -------------- # + jpeger = DiffJPEG(differentiable=False).cuda() + img_gt = img2tensor(img_gt) + img_gt = torch.stack([img_gt, img_gt]).cuda() + quality = img_gt.new_tensor([20, 40]) + out = jpeger(img_gt, quality=quality) + + cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) + cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1])) diff --git a/ultraled/utils/dist_util.py b/ultraled/utils/dist_util.py new file mode 100644 index 0000000..0fab887 --- /dev/null +++ b/ultraled/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/ultraled/utils/download_util.py b/ultraled/utils/download_util.py new file mode 100644 index 0000000..6adda71 --- /dev/null +++ b/ultraled/utils/download_util.py @@ -0,0 +1,99 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. 
+ + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/ultraled/utils/file_client.py b/ultraled/utils/file_client.py new file mode 100644 index 0000000..89d83ab --- /dev/null +++ b/ultraled/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. 
If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/ultraled/utils/flow_util.py b/ultraled/utils/flow_util.py new file mode 100644 index 0000000..3d7180b --- /dev/null +++ b/ultraled/utils/flow_util.py @@ -0,0 +1,170 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. 
+ quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. 
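+
+    Example (illustrative round trip with ``quantize_flow`` above, where
+        ``flow`` is an (h, w, 2) float array):
+        >>> dx, dy = quantize_flow(flow, max_val=0.02, norm=True)
+        >>> flow_restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)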
+ """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val + + return dequantized_arr diff --git a/ultraled/utils/img_process_util.py b/ultraled/utils/img_process_util.py new file mode 100644 index 0000000..52e02f0 --- /dev/null +++ b/ultraled/utils/img_process_util.py @@ -0,0 +1,83 @@ +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + + +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +def usm_sharp(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. + + Input image: I; Blurry image: B. + 1. sharp = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * sharp + (1 - Mask) * I + + + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + sharp = img + weight * residual + sharp = np.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img diff --git a/ultraled/utils/img_util.py b/ultraled/utils/img_util.py new file mode 100644 index 0000000..3a5f1da --- /dev/null +++ b/ultraled/utils/img_util.py @@ -0,0 +1,172 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
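+
+    Example (illustrative):
+        >>> t = torch.rand(1, 3, 16, 16)  # 4D RGB tensor in [0, 1]
+        >>> img = tensor2img(t)           # (16, 16, 3) uint8 array in BGR order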
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
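+
+    Example (illustrative):
+        >>> img = np.zeros((64, 64, 3), dtype=np.float32)
+        >>> crop_border(img, 4).shape
+        (56, 56, 3)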
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/ultraled/utils/lmdb_util.py b/ultraled/utils/lmdb_util.py new file mode 100644 index 0000000..e0a10f6 --- /dev/null +++ b/ultraled/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/ultraled/utils/logger.py b/ultraled/utils/logger.py new file mode 100644 index 0000000..dcbedf1 --- /dev/null +++ b/ultraled/utils/logger.py @@ -0,0 +1,217 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + def reset_start_time(self): + self.start_time = time.time() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. 
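+
+        Example of a resulting log line (illustrative values):
+            [exp_n..][epoch:  1, iter:     200, lr:(1.000e-04,)] [eta: 0:30:00, time (data): 0.500 (0.010)] l_pix: 1.2345e-02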
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + if isinstance(v, str): + message += f'{k}: {v} ' + continue + else: + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = get_root_logger() + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. 
+ """ + import torch + import torchvision + + # from basicsr.version import __version__ +# msg = r""" +# ____ _ _____ ____ +# / __ ) ____ _ _____ (_)_____/ ___/ / __ \ +# / __ |/ __ `// ___// // ___/\__ \ / /_/ / +# / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ +# /_____/ \__,_//____//_/ \___//____//_/ |_| +# ______ __ __ __ __ +# / ____/____ ____ ____/ / / / __ __ _____ / /__ / / +# / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / +# / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ +# \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) +# """ +# msg += ('\nVersion Information: ' +# f'\n\tBasicSR: {__version__}' +# f'\n\tPyTorch: {torch.__version__}' +# f'\n\tTorchVision: {torchvision.__version__}') +# return msg diff --git a/ultraled/utils/matlab_functions.py b/ultraled/utils/matlab_functions.py new file mode 100644 index 0000000..a201f79 --- /dev/null +++ b/ultraled/utils/matlab_functions.py @@ -0,0 +1,178 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + squeeze_flag = False + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + 
for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if squeeze_flag: + out_2 = out_2.squeeze(0) + if numpy_type: + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + + return out_2 diff --git a/ultraled/utils/misc.py b/ultraled/utils/misc.py new file mode 100644 index 0000000..728fef8 --- /dev/null +++ b/ultraled/utils/misc.py @@ -0,0 +1,141 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. 
+ """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/ultraled/utils/options.py b/ultraled/utils/options.py new file mode 100644 index 0000000..09b749b --- /dev/null +++ b/ultraled/utils/options.py @@ -0,0 +1,211 @@ +import argparse +import random +import torch +import yaml +from collections import OrderedDict +from os import path as osp +import os +from ultraled.utils import set_random_seed +from ultraled.utils.dist_util import get_dist_info, init_dist, master_only + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + tuple: yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def yaml_load(f): + """Load yaml file or string. + + Args: + f (str): File path or a python string. + + Returns: + dict: Loaded dict. + """ + if os.path.isfile(f): + with open(f, 'r') as f: + return yaml.load(f, Loader=ordered_yaml()[0]) + else: + return yaml.load(f, Loader=ordered_yaml()[0]) + + + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. 
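+
+    Example (illustrative):
+        >>> opt = {'name': 'exp', 'train': {'lr': 0.0001}}
+        >>> print(dict2str(opt))  # prints a nested, indented listing of the options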
+ """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + with open(args.opt, mode='r') as f: + opt = yaml.load(f, Loader=ordered_yaml()[0]) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + 
opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) diff --git a/ultraled/utils/process.py b/ultraled/utils/process.py new file mode 100644 index 0000000..32e302b --- /dev/null +++ b/ultraled/utils/process.py @@ -0,0 +1,223 @@ +"""Forward processing of raw data to sRGB images. + +Unprocessing Images for Learned Raw Denoising +http://timothybrooks.com/tech/unprocessing +""" + +import numpy as np +import torch +from ultraled.utils.torchinterp1d import Interp1d +from os.path import join + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +def apply_gains(bayer_images, wbs): + """Applies white balance to a batch of Bayer images.""" + N, C, _, _ = bayer_images.shape + outs = bayer_images * wbs.view(N, C, 1, 1) + return outs + +def apply_ccms(images, ccms): + """Applies color correction matrices.""" + images = images.permute( + 0, 2, 3, 1) # Permute the image tensor to BxHxWxC format from BxCxHxW format + images = images[:, :, :, None, :] + ccms = ccms[:, None, None, :, :] + outs = torch.sum(images * ccms, dim=-1) + # Re-Permute the tensor back to BxCxHxW format + outs = outs.permute(0, 3, 1, 2) + return outs + + +def gamma_compression(images, gamma=2.2): + """Converts from linear to gamma space.""" + outs = torch.clamp(images, min=1e-8) ** (1 / gamma) + # outs = (1 + gamma[0]) * np.power(images, 1.0/gamma[1]) - gamma[0] + gamma[2]*images + outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255 + return outs + + +def gamma_compression_grad(images, gamma=2.2): + """Converts from linear to gamma space.""" + outs = torch.clamp(images, min=1e-8) ** (1 / gamma) + # outs = (1 + gamma[0]) * np.power(images, 1.0/gamma[1]) - gamma[0] + gamma[2]*images + return outs + + +def binning(bayer_images): + """RGBG -> RGB""" + lin_rgb = torch.stack([ + bayer_images[:,0,...], + torch.mean(bayer_images[:, [1,3], ...], dim=1), + bayer_images[:,2,...]], dim=1) + + return lin_rgb + + +def process(bayer_images, wbs, cam2rgbs, gamma=2.2, CRF=None): + """Processes a batch of Bayer RGBG images into sRGB images.""" + orig_img = bayer_images + # White balance. + bayer_images = apply_gains(orig_img, wbs) + # Binning + bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0) + images = binning(bayer_images) + # Color correction. + images = apply_ccms(images, cam2rgbs) + # Gamma compression. 
+ images = torch.clamp(images, min=0.0, max=1.0) + if CRF is None: + images = gamma_compression(images, gamma) + else: + images = camera_response_function(images, CRF) + + return images + + +def process_grad(bayer_images, wbs, cam2rgbs, gamma=2.2, CRF=None): + """Processes a batch of Bayer RGBG images into sRGB images.""" + orig_img = bayer_images + # White balance. + bayer_images = apply_gains(orig_img, wbs) + # Binning + bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0) + images = binning(bayer_images) + # Color correction. + images = apply_ccms(images, cam2rgbs) + # Gamma compression. + images = torch.clamp(images, min=0.0, max=1.0) + if CRF is None: + images = gamma_compression_grad(images, gamma) + else: + images = camera_response_function_grad(images, CRF) + + return images + + +def camera_response_function(images, CRF): + E, fs = CRF # unpack CRF data + + outs = torch.zeros_like(images) + device = images.device + + for i in range(images.shape[0]): + img = images[i].view(3, -1) + out = Interp1d()(E.to(device), fs.to(device), img) + outs[i, ...] = out.view(3, images.shape[2], images.shape[3]) + + outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255 + return outs + +def camera_response_function_grad(images, CRF): + E, fs = CRF # unpack CRF data + + outs = torch.zeros_like(images) + device = images.device + + for i in range(images.shape[0]): + img = images[i].view(3, -1) + out = Interp1d()(E.to(device), fs.to(device), img) + outs[i, ...] = out.view(3, images.shape[2], images.shape[3]) + + return outs + +def raw2rgb(packed_raw, raw, CRF=None, gamma=2.2): + """Raw2RGB pipeline (preprocess version)""" + wb = np.array(raw.camera_whitebalance) + wb /= wb[1] + cam2rgb = raw.rgb_camera_matrix[:3, :3] + + if isinstance(packed_raw, np.ndarray): + packed_raw = torch.from_numpy(packed_raw).float() + + wb = torch.from_numpy(wb).float().to(packed_raw.device) + cam2rgb = torch.from_numpy(cam2rgb).float().to(packed_raw.device) + + out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy() + + return out + + +def raw2rgb_v2(packed_raw, wb, ccm, CRF=None, gamma=2.2): # RGBG + packed_raw = torch.from_numpy(packed_raw).float() + wb = torch.from_numpy(wb).float() + cam2rgb = torch.from_numpy(ccm).float() + out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy() + return out + + +def raw2rgb_torch(packed_raw, wb, ccm, CRF=None, gamma=2.2, batch=False): # RGBG + if batch: + out = process(packed_raw, wbs=wb, cam2rgbs=ccm, gamma=gamma, CRF=CRF) + else: + out = process(packed_raw[None], wbs=wb[None], cam2rgbs=ccm[None], gamma=gamma, CRF=CRF) + return out + +def raw2rgb_torch_grad(packed_raw, wb, ccm, CRF=None, gamma=2.2): # RGBG + out = process_grad(packed_raw, wbs=wb, cam2rgbs=ccm, gamma=gamma, CRF=CRF) + return out + +def raw2rgb_postprocess(packed_raw, raw, CRF=None): + """Raw2RGB pipeline (postprocess version)""" + assert packed_raw.ndimension() == 4 and packed_raw.shape[0] == 1 + wb = np.array(raw.camera_whitebalance) + wb /= wb[1] + cam2rgb = raw.rgb_camera_matrix[:3, :3] + + wb = torch.from_numpy(wb[None]).float().to(packed_raw.device) + cam2rgb = torch.from_numpy(cam2rgb[None]).float().to(packed_raw.device) + out = process(packed_raw, wbs=wb, cam2rgbs=cam2rgb, gamma=2.2, CRF=CRF) + return out + +def read_wb_ccm(raw): + wb = np.array(raw.camera_whitebalance) + wb /= wb[1] + wb = wb.astype(np.float32) + ccm = raw.rgb_camera_matrix[:3, :3].astype(np.float32) + return wb, ccm + + +def 
read_emor(address): + def _read_curve(lst): + curve = [l.strip() for l in lst] + curve = ' '.join(curve) + curve = np.array(curve.split()).astype(np.float32) + return curve + + with open(address) as f: + lines = f.readlines() + k = 1 + E = _read_curve(lines[k:k+256]) + k += 257 + f0 = _read_curve(lines[k:k+256]) + hs = [] + for _ in range(25): + k += 257 + hs.append(_read_curve(lines[k:k+256])) + + hs = np.array(hs) + + return E, f0, hs + + +def read_dorf(address): + with open(address) as f: + lines = f.readlines() + curve_names = lines[0::6] + Es = lines[3::6] + Bs = lines[5::6] + + Es = [np.array(E.strip().split()).astype(np.float32) for E in Es] + Bs = [np.array(B.strip().split()).astype(np.float32) for B in Bs] + + return curve_names, Es, Bs + + +def load_CRF(EMoR_path): + # init CRF function + fs = np.loadtxt(join(EMoR_path, 'CRF_SonyA7S2_5.txt')) + E, _, _ = read_emor(join(EMoR_path, 'emor.txt')) + E = torch.from_numpy(E).repeat(3, 1) + fs = torch.from_numpy(fs) + CRF = (E, fs) + return CRF diff --git a/ultraled/utils/registry.py b/ultraled/utils/registry.py new file mode 100644 index 0000000..5e72ef7 --- /dev/null +++ b/ultraled/utils/registry.py @@ -0,0 +1,88 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None, suffix=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. 
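+
+        Example (illustrative; ``suffix`` disambiguates same-named objects):
+            >>> @ARCH_REGISTRY.register(suffix='basicsr')
+            ... class MyArch():
+            ...     pass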
+ """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class, suffix) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj, suffix) + + def get(self, name, suffix='basicsr'): + ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/ultraled/utils/torchinterp1d.py b/ultraled/utils/torchinterp1d.py new file mode 100644 index 0000000..5e7d1ff --- /dev/null +++ b/ultraled/utils/torchinterp1d.py @@ -0,0 +1,164 @@ +import torch +import contextlib + +class Interp1d(torch.autograd.Function): + def __call__(self, x, y, xnew, out=None): + return self.forward(x, y, xnew, out) + + def forward(ctx, x, y, xnew, out=None): + """ + Linear 1D interpolation on the GPU for Pytorch. + This function returns interpolated values of a set of 1-D functions at + the desired query points `xnew`. + This function is working similarly to Matlab™ or scipy functions with + the `linear` interpolation mode on, except that it parallelises over + any number of desired interpolation problems. + The code will run on GPU if all the tensors provided are on a cuda + device. + + Parameters + ---------- + x : (N, ) or (D, N) Pytorch Tensor + A 1-D or 2-D tensor of real values. + y : (N,) or (D, N) Pytorch Tensor + A 1-D or 2-D tensor of real values. The length of `y` along its + last dimension must be the same as that of `x` + xnew : (P,) or (D, P) Pytorch Tensor + A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if + _both_ `x` and `y` are 1-D. Otherwise, its length along the first + dimension must be the same as that of whichever `x` and `y` is 2-D. + out : Pytorch Tensor, same shape as `xnew` + Tensor for the output. If None: allocated automatically. + + """ + # making the vectors at least 2D + is_flat = {} + require_grad = {} + v = {} + device = [] + eps = torch.finfo(y.dtype).eps + for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items(): + assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\ + 'at most 2-D.' + if len(vec.shape) == 1: + v[name] = vec[None, :] + else: + v[name] = vec + is_flat[name] = v[name].shape[0] == 1 + require_grad[name] = vec.requires_grad + device = list(set(device + [str(vec.device)])) + assert len(device) == 1, 'All parameters must be on the same device.' 
+ device = device[0] + + # Checking for the dimensions + assert (v['x'].shape[1] == v['y'].shape[1] + and ( + v['x'].shape[0] == v['y'].shape[0] + or v['x'].shape[0] == 1 + or v['y'].shape[0] == 1 + ) + ), ("x and y must have the same number of columns, and either " + "the same number of row or one of them having only one " + "row.") + + reshaped_xnew = False + if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1) + and (v['xnew'].shape[0] > 1)): + # if there is only one row for both x and y, there is no need to + # loop over the rows of xnew because they will all have to face the + # same interpolation problem. We should just stack them together to + # call interp1d and put them back in place afterwards. + original_xnew_shape = v['xnew'].shape + v['xnew'] = v['xnew'].contiguous().view(1, -1) + reshaped_xnew = True + + # identify the dimensions of output and check if the one provided is ok + D = max(v['x'].shape[0], v['xnew'].shape[0]) + shape_ynew = (D, v['xnew'].shape[-1]) + if out is not None: + if out.numel() != shape_ynew[0]*shape_ynew[1]: + # The output provided is of incorrect shape. + # Going for a new one + out = None + else: + ynew = out.reshape(shape_ynew) + if out is None: + ynew = torch.zeros(*shape_ynew, device=device) + + # moving everything to the desired device in case it was not there + # already (not handling the case things do not fit entirely, user will + # do it if required.) + for name in v: + v[name] = v[name].to(device) + + # calling searchsorted on the x values. + ind = ynew.long() + + # expanding xnew to match the number of rows of x in case only one xnew is + # provided + if v['xnew'].shape[0] == 1: + v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1) + + torch.searchsorted(v['x'].contiguous(), + v['xnew'].contiguous(), out=ind) + + # the `-1` is because searchsorted looks for the index where the values + # must be inserted to preserve order. And we want the index of the + # preceeding value. + ind -= 1 + # we clamp the index, because the number of intervals is x.shape-1, + # and the left neighbour should hence be at most number of intervals + # -1, i.e. number of columns in x -2 + ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1) + + # helper function to select stuff according to the found indices. + def sel(name): + if is_flat[name]: + return v[name].contiguous().view(-1)[ind] + return torch.gather(v[name], 1, ind) + + # activating gradient storing for everything now + enable_grad = False + saved_inputs = [] + for name in ['x', 'y', 'xnew']: + if require_grad[name]: + enable_grad = True + saved_inputs += [v[name]] + else: + saved_inputs += [None, ] + # assuming x are sorted in the dimension 1, computing the slopes for + # the segments + is_flat['slopes'] = is_flat['x'] + # now we have found the indices of the neighbors, we start building the + # output. 
Hence, we start also activating gradient tracking + with torch.enable_grad() if enable_grad else contextlib.suppress(): + v['slopes'] = ( + (v['y'][:, 1:]-v['y'][:, :-1]) + / + (eps + (v['x'][:, 1:]-v['x'][:, :-1])) + ) + + # now build the linear interpolation + ynew = sel('y') + sel('slopes')*( + v['xnew'] - sel('x')) + + if reshaped_xnew: + ynew = ynew.view(original_xnew_shape) + + ctx.save_for_backward(ynew, *saved_inputs) + return ynew + + @staticmethod + def backward(ctx, grad_out): + inputs = ctx.saved_tensors[1:] + gradients = torch.autograd.grad( + ctx.saved_tensors[0], + [i for i in inputs if i is not None], + grad_out, retain_graph=True) + result = [None, ] * 5 + pos = 0 + for index in range(len(inputs)): + if inputs[index] is not None: + result[index] = gradients[pos] + pos += 1 + return (*result,) \ No newline at end of file From c6b31ad50a379407af2922968a97a7c662af41ec Mon Sep 17 00:00:00 2001 From: mya012 <96360296+mya012@users.noreply.github.com> Date: Mon, 20 Oct 2025 11:37:18 +0800 Subject: [PATCH 6/6] Add files via upload --- ultraled/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 294 bytes ultraled/__pycache__/test.cpython-38.pyc | Bin 0 -> 1608 bytes ultraled/__pycache__/train.cpython-38.pyc | Bin 0 -> 6402 bytes ultraled/archs/__init__.py | 25 + .../archs/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 1134 bytes .../__pycache__/cunet_arch.cpython-38.pyc | Bin 0 -> 6212 bytes .../__pycache__/norm_util.cpython-38.pyc | Bin 0 -> 2966 bytes .../__pycache__/unet_arch.cpython-38.pyc | Bin 0 -> 3534 bytes ultraled/archs/arch_util.py | 318 +++++++++++ ultraled/archs/cunet_arch.py | 212 +++++++ ultraled/archs/norm_util.py | 59 ++ ultraled/archs/unet_arch.py | 126 +++++ ultraled/data/__init__.py | 108 ++++ .../data/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 3810 bytes .../__pycache__/collate_fn.cpython-38.pyc | Bin 0 -> 1055 bytes .../commom_noise_util.cpython-38.pyc | Bin 0 -> 11557 bytes .../__pycache__/data_sampler.cpython-38.pyc | Bin 0 -> 2147 bytes ...hdr_paired_noiseraw_dataset.cpython-38.pyc | Bin 0 -> 11868 bytes .../data/__pycache__/hdr_util.cpython-38.pyc | Bin 0 -> 7173 bytes .../noise_util_rawhdr.cpython-38.pyc | Bin 0 -> 3333 bytes .../__pycache__/part_enhance.cpython-38.pyc | Bin 0 -> 6704 bytes .../prefetch_dataloader.cpython-38.pyc | Bin 0 -> 4341 bytes .../data/__pycache__/raw_utils.cpython-38.pyc | Bin 0 -> 5433 bytes ultraled/data/collate_fn.py | 25 + ultraled/data/commom_noise_util.py | 316 +++++++++++ ultraled/data/data_sampler.py | 48 ++ ultraled/data/hdr_paired_noiseraw_dataset.py | 520 ++++++++++++++++++ ultraled/data/hdr_util.py | 234 ++++++++ ultraled/data/noise_util_rawhdr.py | 254 +++++++++ ultraled/data/part_enhance.py | 206 +++++++ ultraled/data/prefetch_dataloader.py | 125 +++++ ultraled/data/raw_utils.py | 191 +++++++ ultraled/losses/__init__.py | 31 ++ .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 1387 bytes .../__pycache__/basic_loss.cpython-38.pyc | Bin 0 -> 11680 bytes .../__pycache__/gan_loss.cpython-38.pyc | Bin 0 -> 6760 bytes .../__pycache__/loss_util.cpython-38.pyc | Bin 0 -> 4467 bytes ultraled/losses/archive/conditional_loss.py | 55 ++ ultraled/losses/basic_loss.py | 380 +++++++++++++ ultraled/losses/gan_loss.py | 208 +++++++ ultraled/losses/loss_util.py | 145 +++++ ultraled/metrics/__init__.py | 20 + .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 752 bytes .../__pycache__/metric_util.cpython-38.pyc | Bin 0 -> 1480 bytes .../metrics/__pycache__/niqe.cpython-38.pyc | Bin 0 -> 6824 bytes 
.../__pycache__/psnr_ssim.cpython-38.pyc | Bin 0 -> 7307 bytes ultraled/metrics/fid.py | 93 ++++ ultraled/metrics/metric_util.py | 45 ++ ultraled/metrics/niqe.py | 197 +++++++ ultraled/metrics/psnr_ssim.py | 233 ++++++++ .../metrics/test_metrics/test_psnr_ssim.py | 52 ++ 51 files changed, 4226 insertions(+) create mode 100644 ultraled/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/__pycache__/test.cpython-38.pyc create mode 100644 ultraled/__pycache__/train.cpython-38.pyc create mode 100644 ultraled/archs/__init__.py create mode 100644 ultraled/archs/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/archs/__pycache__/cunet_arch.cpython-38.pyc create mode 100644 ultraled/archs/__pycache__/norm_util.cpython-38.pyc create mode 100644 ultraled/archs/__pycache__/unet_arch.cpython-38.pyc create mode 100644 ultraled/archs/arch_util.py create mode 100644 ultraled/archs/cunet_arch.py create mode 100644 ultraled/archs/norm_util.py create mode 100644 ultraled/archs/unet_arch.py create mode 100644 ultraled/data/__init__.py create mode 100644 ultraled/data/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/collate_fn.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/commom_noise_util.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/data_sampler.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/hdr_paired_noiseraw_dataset.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/hdr_util.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/noise_util_rawhdr.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/part_enhance.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/prefetch_dataloader.cpython-38.pyc create mode 100644 ultraled/data/__pycache__/raw_utils.cpython-38.pyc create mode 100644 ultraled/data/collate_fn.py create mode 100644 ultraled/data/commom_noise_util.py create mode 100644 ultraled/data/data_sampler.py create mode 100644 ultraled/data/hdr_paired_noiseraw_dataset.py create mode 100644 ultraled/data/hdr_util.py create mode 100644 ultraled/data/noise_util_rawhdr.py create mode 100644 ultraled/data/part_enhance.py create mode 100644 ultraled/data/prefetch_dataloader.py create mode 100644 ultraled/data/raw_utils.py create mode 100644 ultraled/losses/__init__.py create mode 100644 ultraled/losses/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/losses/__pycache__/basic_loss.cpython-38.pyc create mode 100644 ultraled/losses/__pycache__/gan_loss.cpython-38.pyc create mode 100644 ultraled/losses/__pycache__/loss_util.cpython-38.pyc create mode 100644 ultraled/losses/archive/conditional_loss.py create mode 100644 ultraled/losses/basic_loss.py create mode 100644 ultraled/losses/gan_loss.py create mode 100644 ultraled/losses/loss_util.py create mode 100644 ultraled/metrics/__init__.py create mode 100644 ultraled/metrics/__pycache__/__init__.cpython-38.pyc create mode 100644 ultraled/metrics/__pycache__/metric_util.cpython-38.pyc create mode 100644 ultraled/metrics/__pycache__/niqe.cpython-38.pyc create mode 100644 ultraled/metrics/__pycache__/psnr_ssim.cpython-38.pyc create mode 100644 ultraled/metrics/fid.py create mode 100644 ultraled/metrics/metric_util.py create mode 100644 ultraled/metrics/niqe.py create mode 100644 ultraled/metrics/psnr_ssim.py create mode 100644 ultraled/metrics/test_metrics/test_psnr_ssim.py diff --git a/ultraled/__pycache__/__init__.cpython-38.pyc b/ultraled/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
zLI#SaOtgu<*oT?&Ksi>wdh(?f%9b)yCuKqX zVOOlW0V|nUG$OUqvTUuS>a@&iIfri*R!@s48#y1;nivaSs4b;sl*?%aZ@Y;{gRw2-l>$6eUe z_vDRZscn^iMw8hjwT;Pahzt6bRsMojlhbd3Xk#s4-$CPLtNf|= zbmVi-CVA#9a!Pq{P9G`Q!w|Cyy(>VHDa+wEUHq!i2w53 zszBy>9cCis$|?&pp>T2=W1Mn62DX&d=GA=9onf$fuKru(xZMPPcH9k=#V@>p=g%cMlh=4X{(z27u6!(#UF|$ zLewimYpQkiL8*lvXh$R};*auGQM;s8wPlU}5`4}bO{8c#nn>5lA5gH9LxiV7L72kf z!gM&cG1C!ll&O~D(Bj|Fagu7Rz>4^nUKj{1x+tV@O=ceG{F@w+r1}F2Iy!MmHBpr9 zm2VAJeW@Y%;ot%H8o1Z*DbZ$LpO#piW>~arkVJGKn*4W=&WO6g+%He9neulEY6xL} zqwJ~@>6mQk1p=+8rV2vsG6J54AYvdmtUSsp?nuN72vNal?{Q#+NF>r+vxd|@ycDPb zn*rO9?~ici7n9NMuK#sbhSiRE;vafhd3O{9JvzsvW1Q0|9T7!5k4dL=I1Lc;zG_cY zp&;C0B?y%`{*Tmu(4LZP+7HF9oJ+F8saDv<-V1&oL-D0`-^kR%Keid`UVuo_kHg{p zk7*uq8`|vgb+ViJGV`T$?*(z+4U+q)EXIuW?i;)O&oPLTxLGxaR{4N)ha5GNb;#3^ z;B*S$^0!z6Bv&Ow?2niWIZKI%IYJoRH17Xi&C1L>^!r}VXJ<^^ycf9P4s-AS8l3(E zk21dce+r!U212NY?ht<4rvM%y0Kfw~HGTyg_*E2$Aqnt9X3!Ej3Cn*7THK|Oh zqAlBNsg8u^Pt)3TEd^rExd5B!tR=C;Dr!I^kYEiGP`hlI)N}fz5t*z!spDygb|(lvxi7uUWjo9DH3|R+Wk~<;VvE^?W}+r)m*t2nIlhrRuBLVHSWO!XL?kX8Df9qk zu1s3u;`HK7nXHOSX+5p5YikI25^w(Vq@9+hucd9d?&YaHQ?UN!>1Fg@OV=p7IM%qG zu1v4UmIk}_6J^qX4yp$RTK0~$r*>M~zY2;v*Tr>qW2Uj2FDRterq|Ln$-6eao=O~; zwoZ9L9~%GHv^l+za|aLD9lN!mw55$5>%5t^(+2B&r%l$=_32G^8$55#6kl2ZQNXV= zQy?KqcIP`BI+&N}wdw6yZZ*wxU0OBCNc``f;eSu+@AbKicWAa5Nv>j7yU;&rK+53d_L;H26?an$2iY~~LWge@rmwR%Bq!^E^3|ou24v`_8)Pw~Q~WsJyL5FO2LXHY1^Mfyl=;W$CU!etFWJ#s`5N}` zup9WlE%kqP6$XDN{r}`LfR>BL27izir*DyT45w4)okvd)SWQ34kq5ts2qPgAzd*H3 zDk%E!*Qp=?D08j+#Au#5L1p?)aev?xa}dMdLGQ8A1t}}-xCtOHWzUs7!R*OIvP?N& z;2nBFW>S?B>hYydDaR+Uyf}~TT^@%yFa;PbNL)2G?@D<3uFZQS*4tFvqN0NWx%m9j zp+`3hncg3cWZWMcpAiae*^74j6li7E5a28H!k9lFcRvX{0;~XT_9^+jCw%1DyD_)V zUu1O0tuJ^Ckr!;`wD$OR*M2-24r4AnhRoPYY?piQ+K*7ZYu}_}^CrEAqnq9F)??vv z@xL#c-=RfSp1C|64JG#c{uO>5l_%6-4LLAzZFS3er)Wb^&lY>09S}|vxy0T1(s5NZKiVZic$-r~&-+sF9k+jgL;=ta zx{so)yyHdvz0l%^)YYbfviehO!CUkx6ADWCc@2fUu^2$W%=qP_kN%9mM-5df zR;i%70*SV{iCWf}qltw^IRPn7AdwUL=I>KMhdh6VA}b&r1&}I3kG(v@D1ZTy2qW^o zM#4kc6K7d#G=$PU_>jvfzohJzped44nL%VLbQ={S%K~WOxX1!#!u6J9MM#h5*JN2S z=aX&F#o0oA=b6eYH}`OmTadhRJLw1)H)tz!z*T~ok}xvinHz<~(p+RZU6fIh?&UT~ zB}@U5+>^`h2XW$cw(>wk`pyj8C?+||iby?_(8(Y~N=WEV=L4cVq~d2({G1BHTUkL` zvQx}iM*Y!K zZx<}Js^37JS}$M@Q*9P5U|!6K5;V-p{{`}8X0I1ghT7}A&c6YH{}mN~LIs&^?qyn> z$OA&U4|e={;?x#*Z(Y)s%-HN;d$S$Anu-K4 zuYN)u2|q}3FO%*^a6flha%e7Ho30r1%SIUtWE`MmB5R*#)1BY*=Al8lybLs_em@{i z3$@XC*Xv~^cPEiKqs&;KPCD{HFTy>MyySL&od%v3GNmx@VM(LFyHAI3LL}O!{ErI( XovehOx3rdSR?VB{n)#O&cKE*m8zt)) literal 0 HcmV?d00001 diff --git a/ultraled/archs/__init__.py b/ultraled/archs/__init__.py new file mode 100644 index 0000000..ecff4ec --- /dev/null +++ b/ultraled/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'ultraled.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git 
0000000000000000000000000000000000000000..a06fb8a9b5772d87c779f5208562d78c84357b54 GIT binary patch literal 2966 zcmZuzOK%)S5bmDm?yT)VY~ehtXm|)RU~GuOQ&9jBf{>CGQjCOFLZiuauXpU(nRR!s zoLHL+Hss_pmmJm{kj#k-e}Ny-R}N7QNL&yS;y_i;?sya2qpt0)>aDKn`s(ZXYJNT- z(7ylSubvO1-*GbA9GH9zU3Ne)!e~x9_-o~)Pdk+2*v_rK-Ek=SoG^#E2ZXttZn^t* z#{-SWe9-t4jSrdzYl5aZ(KHynL!#gb_$Td%YO5$f2Hjzis;n%MJVncF!*3QKNMLc&@Rs$QZ(IUEd1 zp=79b%aA213A?a3QS!qu?IlIQa~bD}RN-aDyU8$DQCNx-Q#TV*tpy>pjUERR<030I z!?FwRdC0OpFC@H_470*$vSKh)a_z)}ah{e!a5*Ro27#2Y78ixPnd%wWvX=~ac=>Lk z(q1gHJ-!0t`>}Z(QRZ;VJ&7FaOL=QT48m5C8ae z{P%CaT;I?O)0Nd$H&f4~suDPMq1(_U1EI)18Pg*|$;cL$lr^@MGj>Px2%f@XWJF>Q zEbmV&{!^C5$O3zF+J$^` zp&N0W6`6`-VS!#Zzf45Z=L#&YDn;7UPFZlNy$fht>mJZJi_I2J?<)^>(FK z&2cb7U*0k~q8G^n^zAUR-4s7oBSyXhj4;x*&yx{d!hKwvpJ%=qC41 zp!UfI@QlCB`2gzj3BG02s?ML@)Ip9+5#!-iS(?YcowW8;qe`O#!R$g!ncrY2xiA0(gXhQ@*-LPbG5Zi{#5=|A0HmgN-$2M|i;LhDfio}o&M<>wjyFZZqK0UI zp*m}fMBue*pdql{L=B(FH5jQtW$^SC+TTWD(on<2IC}~7CKm{m5K46B$WAX&+#S0q zQN${g#k4%=G7#c11wxz&$f^Q*8J5mK!k__B^BQ!C4hevBcK0bb>Bg}qaFw`>B0^DP zPP~EB(~!0bh8Lkrgj|EJxB@jP^ru>lvv)yV<7~cS5x^OIUQBR?pB0F+1=eD72XqE$ zXY}H2V8skLdFBN+OfJQn;8(1nSO?J-S3ss16E=Kr1QfoH0o4r{%z|ocAHe~~j&eu# zKD|M{S$W_gJB>XBw-ww_&d6o-&^dxnrJWVzv$!$6mF$X7>ikE|XF#Nb+orZ}Wm0Ju zxkzeP0ONDGF9ri%Fmq-UnV5@4$k{78s^p_(V!B z@SD$Enc{=WhJgtLr?xY8Zw5*=LUEx%5dk9IyLei>Q>`Mmi^8|v{q^+D^z=MFYc?YS z&tLxf-;>Q{LjHrw^y7i@Yxw8C1K@+pbwUy-iMYdnIt$LqWS zZ=W~$0=$E7h)cR5bf>Mt&yJGf9nn9zavFY{d{7SHkB<3286gFE zPA=#rp`>7s9Z99bDQEC=UU-tkuC_XJRu8N`ce(e1f=^KSxXX~`sn(%kIM^Y=JX@QI z63zJfbCK$n1v^BVXB#t-nlrw^T%-ldf*m5wv#ps(?HS*2F4Cf9!48q;*(G3a7`y$B z3mL1ua3N#MGcD))aC7FX^X!U^Yx<9wxQ@Fz({esbYnFX^4sY4GhQ+zMj=Mh7a(*6f zTlVF%e8sc%<{V$q6ig`jPOVoUbTA$5Hh!5|Jfo zoDa_v<>!SM@D!8|4Xn)E%FFxn~^rO2<8LMEkz6@0a#{YdprT>cKj~$V@?6zNyvS{(P&UY^JCx8#4j z+owr({A}FIj<<^>FXH})C)>Y&@7?WjT8Lhn@a-P#X8HCwONto9t+Quh2@;RfmLC8R z8nG3+LKoON-2mK}MRo`DjTu~{>vV;kFP3}dRym`6EQ#`If`G;J5dQf&fTS1XlKjb? zF!(tWm(z(SojQ@;3YqxQwNfB`&eq8!l)ent$fPC%tw%DD5$JUpNmoXzM26Cl;VK#6 zK0tL%Q{d`guMJ#023|OmhHOAxdq)E6QedA{Ad*3v^b6S6-BU=HPCr(7qc28hP@H_& zPjW@yR}PeVrz!4XEBaK?SkV(jpXh>CjGw^&i3$p61db~Dd{h*pQ!R&SGAOiE6T{;Z zBQ|m6Bs?yXEFTH6k3G_W)1a2({z=mRBJO1z8{!h%0eijyK%6#porp$GNX>7>aVdjI zWdHZw^V`=m@x2%`bAt0x%?U=Ar1Lc!xVOlJN(#B8 za3G4epl!$3nGtV;tXv$h^7LSp11V9yNyq#!TT{cteB+HG$379dKsV|6+RMg*jxz(f z3I2Hg>*D%1Y*&(8K!IIS=p-D%@Moxj&aCc8R_QLv6UQ=8gqnde(k-DBE8RsI5@VP^ zdYbXA4kcRYF3OX@GElUd39TNCmEJ^oQZE)RT3FV{jMGa7 z3&6B&%ewY#Sbb48Dt!s%NmI9MO5e6@TF)D@sXZ60zAP6i{U*wjmTuXSu5H<}o-4AY zJ=<1amF-GjLwT~OTP{jwTP|A9TXIo*E?IqDE>-$%lqWZUxg&3Yv@DmQ(Abfv!4!Ay z3ak<3r6Nhk0 zKSg*Qp@Z-+b&r&mcf>mQ82P@qTgkg(vyvYu zXMDD2eD+Y@H}XEp2Sz?X`OwIR$~Py5a^P&3Usu6kB%bsH$G+(sG3@KK4Zn!J6q#Y$ zKk$y|dAnMTHvG&g!~+~pe!U8DVhX5C4%-A=m^`IbnIhj4LvCh%y~gE!ZnP?zFXmI)l3_$7b}jWgVl{10` 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. 
+ """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. + """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/ultraled/archs/cunet_arch.py b/ultraled/archs/cunet_arch.py new file mode 100644 index 0000000..9f9169a --- /dev/null +++ b/ultraled/archs/cunet_arch.py @@ -0,0 +1,212 @@ +from ultraled.archs.norm_util import MultipleScaleNorm2d, ScaleNorm2d +from ultraled.utils.registry import ARCH_REGISTRY + +import torch +from torch import nn + +### Normalization +from torch.nn import Identity +from torch.nn import BatchNorm2d, InstanceNorm2d +from ultraled.archs.norm_util import LayerNorm2d + +import torch +from torch import nn +from torch.nn import functional as F +import math + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.LeakyReLU(negative_slope=0.2, inplace=True) + ) + + +class ResidualBlock(nn.Module): + def __init__(self, channels) -> None: + super().__init__() + + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x, y): + out1 = self.conv1(x) + out11 = self.lrelu(out1) + out2 = self.conv2(out11) + out21 = out2 + y + + return out21 + + def lrelu(self, x): + outt = torch.max(0.2 * x, x) + return outt + + +@ARCH_REGISTRY.register() +class CUNetArch(nn.Module): + def __init__(self, inchannels=3, outchannels=3, channels=32) -> None: + super().__init__() + + self.conv0_1 = nn.Linear(in_features=300, out_features=256) + self.conv0_2 = nn.Linear(in_features=256, out_features=128) + + self.conv1_1 = nn.Conv2d(inchannels, channels, kernel_size=3, stride=1, padding=1) + self.conv1_2 = 
nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.pool1 = nn.MaxPool2d(kernel_size=2) + + self.conv2_1 = nn.Conv2d(channels, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + self.pool2 = nn.MaxPool2d(kernel_size=2) + + + self.conv3_1 = nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv3_2 = ResidualBlock(channels * 4) + self.pool3 = nn.MaxPool2d(kernel_size=2) + + self.conv4_1 = nn.Conv2d(channels * 4, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv4_2 = ResidualBlock(channels * 8) + self.pool4 = nn.MaxPool2d(kernel_size=2) + + self.conv5_1 = nn.Conv2d(channels * 8, channels * 16, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(channels * 16, channels * 16, kernel_size=3, stride=1, padding=1) + + + self.upv6 = nn.ConvTranspose2d(channels * 16, channels * 8, 2, stride=2) + self.conv6_1 = nn.Conv2d(channels * 16, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv6_2 = ResidualBlock(channels * 8) + + + self.upv7 = nn.ConvTranspose2d(channels * 8, channels * 4, 2, stride=2) + self.conv7_1 = nn.Conv2d(channels * 8, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv7_2 = ResidualBlock(channels * 4) + + self.upv8 = nn.ConvTranspose2d(channels * 4, channels * 2, 2, stride=2) + self.conv8_1 = nn.Conv2d(channels * 4, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv8_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + + self.upv9 = nn.ConvTranspose2d(channels * 2, channels, 2, stride=2) + self.conv9_1 = nn.Conv2d(channels * 2, channels, kernel_size=3, stride=1, padding=1) + self.conv9_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + self.conv10_1 = nn.Conv2d(channels, outchannels, kernel_size=1, stride=1) + + def _check_and_padding(self, x): + + _, _, h, w = x.size() + stride = (2 ** (5 - 1)) + + dh = -h % stride + dw = -w % stride + + top_pad = dh // 2 + bottom_pad = dh - top_pad + left_pad = dw // 2 + right_pad = dw - left_pad + self.crop_indices = (left_pad, w+left_pad, top_pad, h+top_pad) + + padded_tensor = F.pad( + x, (left_pad, right_pad, top_pad, bottom_pad), mode="reflect" + ) + + return padded_tensor + + def _check_and_crop(self, x): + left, right, top, bottom = self.crop_indices + x = x[:, :, top:bottom, left:right] + return x + + def ratio_map_encoding(self, y): + sigma = 30 + r = torch.arange(0, 300).cuda() + + + Hr, Wr = 128, 128 + r = r.view(1, 300, 1, 1).expand(-1, -1, Hr, Wr) + y = F.interpolate(y, size=(Hr, Wr), mode='bilinear', align_corners=False) + r = torch.exp(-((r - y) ** 2) / (2 * sigma * sigma)) / (math.sqrt(2 * math.pi) * sigma) + r = torch.mul(r, 1 / y) + + return r + + + def forward(self, x, y, if_train=True): + r = self.ratio_map_encoding(y) + + if if_train: + Hc4, Wc4 = int(x.size(2) / 4), int(x.size(3) / 4) + else: + Hc4, Wc4 = int(x.size(2) / 4) + 2, int(x.size(3) / 4) + r = F.interpolate(r, size=(Hc4, Wc4), mode='bilinear', align_corners=False) + batch_size, _, H, W = r.shape + r = r.view(batch_size, -1, H * W).permute(0, 2, 1) + + + # MLP + control = self.conv0_1(r) + control = self.conv0_2(control) + control = control.permute(0, 2, 1).view(batch_size, 128, Hc4, Wc4) + + x = self._check_and_padding(x) + + + conv1 = self.lrelu(self.conv1_1(x)) + conv1 = self.lrelu(self.conv1_2(conv1)) + pool1 = self.pool1(conv1) + + conv2 = self.lrelu(self.conv2_1(pool1)) + conv2 = self.lrelu(self.conv2_2(conv2)) + 
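+        # All four downsampling stages in this forward pass reuse self.pool1
+        # (a 2x2 max-pool); the pool2/pool3/pool4 modules created in __init__
+        # are never called.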
pool2 = self.pool1(conv2) + + conv3 = self.lrelu(self.conv3_1(pool2)) + conv3_1 = conv3 + control + conv3 = self.lrelu(self.conv3_2(conv3, conv3_1)) + pool3 = self.pool1(conv3) + + conv4 = self.lrelu(self.conv4_1(pool3)) + conv4 = self.lrelu(self.conv4_2(conv4, conv4)) + pool4 = self.pool1(conv4) + + conv5 = self.lrelu(self.conv5_1(pool4)) + conv5 = self.lrelu(self.conv5_2(conv5)) + + up6 = self.upv6(conv5) + up6 = torch.cat([up6, conv4], 1) + conv6 = self.lrelu(self.conv6_1(up6)) + conv6 = self.lrelu(self.conv6_2(conv6, conv6)) + + up7 = self.upv7(conv6) + up7 = torch.cat([up7, conv3_1], 1) + conv7 = self.lrelu(self.conv7_1(up7)) + conv7 = self.lrelu(self.conv7_2(conv7, conv7)) + + up8 = self.upv8(conv7) + up8 = torch.cat([up8, conv2], 1) + conv8 = self.lrelu(self.conv8_1(up8)) + conv8 = self.lrelu(self.conv8_2(conv8)) + + up9 = self.upv9(conv8) + up9 = torch.cat([up9, conv1], 1) + conv9 = self.lrelu(self.conv9_1(up9)) + conv9 = self.lrelu(self.conv9_2(conv9)) + + conv10 = self.conv10_1(conv9) + out = self._check_and_crop(conv10) + return out + + + + def lrelu(self, x): + outt = torch.max(0.2 * x, x) + return outt \ No newline at end of file diff --git a/ultraled/archs/norm_util.py b/ultraled/archs/norm_util.py new file mode 100644 index 0000000..ff5b655 --- /dev/null +++ b/ultraled/archs/norm_util.py @@ -0,0 +1,59 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class LayerNorm2d(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, num_features, eps=1e-6, affine=True, data_format="channels_first", track_running_stats=False): + super().__init__() + self.weight = nn.Parameter(torch.ones(num_features)) if affine else None + self.bias = nn.Parameter(torch.zeros(num_features)) if affine else None + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.num_features = (num_features, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.num_features, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + if self.weight is not None: + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class ScaleNorm2d(nn.Module): + def __init__(self, num_features, bias=True, *, init_weight=1.0, init_bias=0.0) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1) * init_weight, requires_grad=True) + self.bias = nn.Parameter(torch.ones(1, num_features, 1, 1) * init_bias, requires_grad=bias) + + def forward(self, x): + return self.weight * x + self.bias + +class MultipleScaleNorm2d(nn.Module): + def __init__(self, num_features, bias=True, numbers=1, *, init_weight=1.0, init_bias=0.0) -> None: + super().__init__() + self.norms = nn.ModuleList() + for _ in range(numbers): + self.norms.append(ScaleNorm2d( + num_features, + bias=bias, + init_weight=init_weight, + init_bias=init_bias + )) + # self.deploy_norm = ScaleNorm2d(num_features, False) + # self.deploy_norm.weight.requires_grad_(False) + self.numbers = numbers + + def forward(self, x, 
idx=0): + assert idx < self.numbers + return self.norms[idx](x) diff --git a/ultraled/archs/unet_arch.py b/ultraled/archs/unet_arch.py new file mode 100644 index 0000000..584443a --- /dev/null +++ b/ultraled/archs/unet_arch.py @@ -0,0 +1,126 @@ +from ultraled.utils.registry import ARCH_REGISTRY + +import torch +from torch import nn +from torch.nn import functional as F + +# The same arch as "Learning to see in the dark" (CVPR 2018) +@ARCH_REGISTRY.register() +class UNetArch(nn.Module): + def __init__(self, inchannels=3, outchannels=3, channels=32) -> None: + super().__init__() + + self.conv1_1 = nn.Conv2d(inchannels, channels, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.pool1 = nn.MaxPool2d(kernel_size=2) + + self.conv2_1 = nn.Conv2d(channels, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + self.pool2 = nn.MaxPool2d(kernel_size=2) + + self.conv3_1 = nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(channels * 4, channels * 4, kernel_size=3, stride=1, padding=1) + self.pool3 = nn.MaxPool2d(kernel_size=2) + + self.conv4_1 = nn.Conv2d(channels * 4, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(channels * 8, channels * 8, kernel_size=3, stride=1, padding=1) + self.pool4 = nn.MaxPool2d(kernel_size=2) + + self.conv5_1 = nn.Conv2d(channels * 8, channels * 16, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(channels * 16, channels * 16, kernel_size=3, stride=1, padding=1) + + self.upv6 = nn.ConvTranspose2d(channels * 16, channels * 8, 2, stride=2) + self.conv6_1 = nn.Conv2d(channels * 16, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv6_2 = nn.Conv2d(channels * 8, channels * 8, kernel_size=3, stride=1, padding=1) + + self.upv7 = nn.ConvTranspose2d(channels * 8, channels * 4, 2, stride=2) + self.conv7_1 = nn.Conv2d(channels * 8, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv7_2 = nn.Conv2d(channels * 4, channels * 4, kernel_size=3, stride=1, padding=1) + + self.upv8 = nn.ConvTranspose2d(channels * 4, channels * 2, 2, stride=2) + self.conv8_1 = nn.Conv2d(channels * 4, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv8_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + + self.upv9 = nn.ConvTranspose2d(channels * 2, channels, 2, stride=2) + self.conv9_1 = nn.Conv2d(channels * 2, channels, kernel_size=3, stride=1, padding=1) + self.conv9_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + self.conv10_1 = nn.Conv2d(channels, outchannels, kernel_size=1, stride=1) + + def _check_and_padding(self, x): + ### This function is totally writen by ChatGPT + # Calculate the required size based on the input size and required factor + _, _, h, w = x.size() + stride = (2 ** (5 - 1)) + + # Calculate the number of pixels needed to reach the required size + dh = -h % stride + dw = -w % stride + + # Calculate the amount of padding needed for each side + top_pad = dh // 2 + bottom_pad = dh - top_pad + left_pad = dw // 2 + right_pad = dw - left_pad + self.crop_indices = (left_pad, w+left_pad, top_pad, h+top_pad) + + # Pad the tensor with reflect mode + padded_tensor = F.pad( + x, (left_pad, right_pad, top_pad, bottom_pad), mode="reflect" + ) + + return padded_tensor + + def _check_and_crop(self, x): + left, right, top, bottom = self.crop_indices + x = x[:, :, 
top:bottom, left:right] + return x + + def forward(self, x): + x = self._check_and_padding(x) + conv1 = self.lrelu(self.conv1_1(x)) + conv1 = self.lrelu(self.conv1_2(conv1)) + pool1 = self.pool1(conv1) + + conv2 = self.lrelu(self.conv2_1(pool1)) + conv2 = self.lrelu(self.conv2_2(conv2)) + pool2 = self.pool1(conv2) + + conv3 = self.lrelu(self.conv3_1(pool2)) + conv3 = self.lrelu(self.conv3_2(conv3)) + pool3 = self.pool1(conv3) + + conv4 = self.lrelu(self.conv4_1(pool3)) + conv4 = self.lrelu(self.conv4_2(conv4)) + pool4 = self.pool1(conv4) + + conv5 = self.lrelu(self.conv5_1(pool4)) + conv5 = self.lrelu(self.conv5_2(conv5)) + + up6 = self.upv6(conv5) + up6 = torch.cat([up6, conv4], 1) + conv6 = self.lrelu(self.conv6_1(up6)) + conv6 = self.lrelu(self.conv6_2(conv6)) + + up7 = self.upv7(conv6) + up7 = torch.cat([up7, conv3], 1) + conv7 = self.lrelu(self.conv7_1(up7)) + conv7 = self.lrelu(self.conv7_2(conv7)) + + up8 = self.upv8(conv7) + up8 = torch.cat([up8, conv2], 1) + conv8 = self.lrelu(self.conv8_1(up8)) + conv8 = self.lrelu(self.conv8_2(conv8)) + + up9 = self.upv9(conv8) + up9 = torch.cat([up9, conv1], 1) + conv9 = self.lrelu(self.conv9_1(up9)) + conv9 = self.lrelu(self.conv9_2(conv9)) + + conv10 = self.conv10_1(conv9) + out = self._check_and_crop(conv10) + return out + + def lrelu(self, x): + outt = torch.max(0.2 * x, x) + return outt \ No newline at end of file diff --git a/ultraled/data/__init__.py b/ultraled/data/__init__.py new file mode 100644 index 0000000..9ae7c46 --- /dev/null +++ b/ultraled/data/__init__.py @@ -0,0 +1,108 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from ultraled.data.prefetch_dataloader import PrefetchDataLoader +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.dist_util import get_dist_info +from ultraled.utils.registry import DATASET_REGISTRY +import ultraled.data.collate_fn as basicsr_collate + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'ultraled.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must contain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. 
+ sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + batch_size = dataset_opt.get('batch_size_per_gpu', 1) + num_workers = dataset_opt.get('num_worker_per_gpu', 0) + dataloader_args = dict(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + else: + raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) + collate_fn_class = dataset_opt.get('collate_fn', None) + if collate_fn_class is not None: + assert hasattr(basicsr_collate, collate_fn_class) + dataloader_args['collate_fn'] = getattr(basicsr_collate, collate_fn_class)() + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/ultraled/data/__pycache__/__init__.cpython-38.pyc b/ultraled/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc94bee349db0c3fa7b7bc49b6f6436a57ada3f6 GIT binary patch literal 3810 zcmaJ^UvC@75x+ei$>Z^lC`z&{Cym!`8;MOyB|%XXu9Mb@>>xlO1hL(q*93;Fvt)-#NMVdLUcg~qR;WdDYW8H|L6TYQOY>UPDg=&$8k`digsGpt2* zw@#3^653J2ZD`sGov7)yG+hnbQOE6Qx)!cPUAG&px~sZf4|~y?yQXP7T#q)~4Na5q zO0;Pj#E7oCThP;B&e(LX@z&`}P+sOgI3s7oeT6ky>#@OC`70-6&p09Ob=KBB*L91s z4r@F%+#772tup5^ac{C7TZ5h-GIHPOZ+r{Z8vA{sY{q#yOwwbZ>!~kf;DZ2lP$mL#Vplu8a;b4D21GX7+KZh@KRthrcakRA&Z+qaa>W0pM%ePTc2V4^vALy<53ZU_L4@!eay5s$~mQ$HT>N)EOS6UKKxyt}tM4W;ly z&UVq7T@Q>7r02nE%DElFE)0_>eP;u8uR&EaOv5q6I*>0IiR0SHmx0hmuFr`yP6;z| z@{RG-JR_x9WY$?MJHTZJpl>6H(nQEGI8g0^^rD1KL$F13L7-}kf{=S?oi?ec)?h?y zvk63IL4_JVr&E;V#gfObU>(SD;Y$4bE3XE_s~3 zz=eGUm$HzaH5T)KH?uyYKc&h-;WLK@8AV*nLBB2#!=eF|s(^P0JZWX2d#Hxz4MRW6 zJWtu4J{`~NS5)(YwW?k=RITu~%jbs|+%+hfzUge6^hk|R@-4#s-VhgCJgY$1P&CZ` z0^_;%mU&GYC&s7dygs+{8Z(bf@z=Z}-hk1{mN9R@7fx{Q#5=jksxW$+)nYj3d9&ci zo9pLXnsy{0B5$yI?&KD;S>s7<-U7-wvPI79yv3Tgjl2r&hWN8IWd-=^&~81kVv{*A 
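For reference, a minimal sketch of how the build_dataset and build_dataloader helpers above can be driven from an options dict; the dataset name and `type` string below are placeholders for illustration, not names confirmed by this patch:

from ultraled.data import build_dataset, build_dataloader

# Hypothetical training options; the keys mirror the ones read by
# build_dataset / build_dataloader ('phase', 'batch_size_per_gpu', ...).
dataset_opt = {
    'name': 'HDR-train',                 # used only for logging
    'type': 'HDRPairedNoiseRawDataset',  # placeholder registered dataset type
    'phase': 'train',
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 4,
    'pin_memory': True,
}

train_set = build_dataset(dataset_opt)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)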
zM64i-KLe#AGPeFxMWD=Og>X{x{~4I7K9R=qcc6>!kWkU~Ez;BpTqcboy>KBCJ?yLC z+7qHdh4kr65ZDmkgQ%C%yGH@_VbMtZh$=jYgn0?Om9VCYT#31bKu^-!JaI&li9_P# zwt`OQ9(qSKA&NoQ5H02r7hb~u)LVPvCC5A?t(d+hsAKaHoZSDIf*mC650NJ2 zAVRrgM5s1=e4hwWOZ3aDc6lmS^rx)o5AEO+52ZhqBlJhmw=wYr5-&ouk?5Dw9&rV# z_!5biAcAM%G{q7St2VL{99`3n}C;Zc$m21SjE>XpvYWu}*@lCa#ls zodoeau|i^%L_k7;?`KF;2>e;nLWo9bjI{~GyJ25+sHRJzPohfVL|`Yl`75M|rv5Oj zO;}jrZJ_G{p===ToIcOy*<~Q^ie83aETe>wIWRTK4w1LQ)TkLF>d&1LY9C|(hk#l{ zn3{-5Y>;>uc7Bg49*LcAqn2KBh@DM^o%4cf>bVsCsIzv5UbIDmoOM!+nn8c-BG@Id jw8OKZK)33B#G=Z!1?uK3rMUkparW`3>Bl>8mCio_&e0bm literal 0 HcmV?d00001 diff --git a/ultraled/data/__pycache__/prefetch_dataloader.cpython-38.pyc b/ultraled/data/__pycache__/prefetch_dataloader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca3a084a139355369e9e424f61f15a05cc4629d9 GIT binary patch literal 4341 zcmbVPPjlPG9o+>$5Tq$pRbsnI+zwjzk2#IVG!T1TA>tweI&a@xVbz>1B~W34Jq3}7PFWSSeKiBt81BmyKCd$8aTsBx5C*&7TdA&1&bYF z?YT#Gw~Crd?4rgM6_h=ctFec&cVu;IDAnRRN_CX#G2dd%EC0b*tkdMWz89GAQ z7n$m(N!I@G_WQT*-o4Wv9jM(jSt&Y4Un}T-#SC6_dSCXkqE{N?wAre^Xq@or>;uJWRLN)M0>kBaT_V{6P#IA_W} zw#U{9AG5$2Q(lQFcgO6wdcru=03NN*cGJ@}wJSv!_miG>x6Q5gf}o%DRS@XPr-V*d z3-JD2y@NbVdTj-!gD8zf z`;*_TxATFL;XuUg7_w=fvO}eT5m9S&AYJsaO!GCu=MDT^{0`sVoYzCEoI=FWzPw0x z6cvWETeumsJ+{vNy7Ut481^`^_Y6EyDK=aZSFd=dX`9(e53gx^l&hvC5p_&EeIz+-;G01%DWXPgf2QnqF;ZQ)6^&l9gaXz5?Mu7MU2a_k& zAc)f_2#ikT8|YSEqwEKik^ZJKS|w>Ct0EaaVd8bK>Qy|~)JLeFxu`K+*}KT5P3tSR z!;NuyUB}#$d3oK+SI}D5MWDHQxm@})Qw%#GuB`UqK^Yc}B7aKqm)KUXuXgjN#;-Pd zsocu5p&0CuQ>8N6UCGqJK=krHBCDTeFtqjuH;bOOAFV$sys9Ha?5otLAoMYAhWH-{ zzc$t+*eNGd@y4;G-&sgQ(2kN#zVv^h#u;WEQFLMq08fv}8As86RYdD|qCQ zwP)|JvGq@0cwrGtj$A-Hcv?OZQvMM4Mo9=V3iRTkTi|-dbu9>@L6~LGo6(k06ovbT zNErnKBX4@AT;sE)zu258jtDTNj9g9u>=8S*PRIr9Bdd&g$LLd1)E3ev)bcuIK4oWU zNci`V${4!@i;2Z%Rgz&zavk{&WkgyLO4N^Dn9Mg4gF!iGGn7y5%h|W zE|+V9jM?BKt_~ZD)BsdZfXHFW+ONx>aF{S>e>U6tpG-d49kfHHT>!d@*9(=LS)p0A zU*N$XkS1GzggSyGpHfD9Q3-60F8MPWmz=(UVRSl+Kv$&5gxVww)FpO< zH)cmS&pWS^C{02VLsP_`1;R}{_$`u|B82IB4s4!kBcz#&0IVTp|L56s@xkNu`%@bI zn#$)zdhRSbW$K+jZme(IZ-*oCdFkM{-@7yM@D1e;!vjCU88shf{xF3vKSUS|aUPQg z7uxV;AwZs)j?|gQj?B^blEWqm*XWS6G9Mpfm}+1vXRmVXl3xW?3A3tFquus`%tcd@njl;bqe+59gk)%oRbdW~- z*yNj6qzM(nX60#!k0@!Jp@pV~Qc1$s6*@6yy4KGEC6X+a+5+p;--J=k_ceb$&*uAR z&_^;G7~>XSf|t15vzGWWm_58U$LVC1W}Jm0kZ3iOl5rGYHYW_GU{Um8a9ti%~8m&I-Xq8CPC}_@*feW{&a>`j&*aG;UKdrw?K)N1C|fGR=nkX(q~B zueV3qKc`gz;C8A1^5`*VlLYHRwpF zVq4=<5rQg&=@Y7|g0z0Ag!+cshqezO)mkYKtKK&YiVmJDJs- zd+)ht?%uiQobQ}7uM7+nG+gig_}@!olbZHd8tlGo48Daso?)8CHO@lKWvbO(UA2a5 zsMd5%)mpBlTHCeJ>S3yvOF|9S?j;bfkaI7`vj-hXH`yI_K^GCVE^Y1YC059+&IO9HRPVf>R;DZf=5Ak6> zf_9XT@p8jJ?*Jd)(A`N$9i*F2e5CV9e()XBJ+y~+h#v;;Fu#xAk5%vEQ~Uw6_w&d2 zgZv@<-xRO#hxrlo9^jAgqnMlKY(c9$zC|ZFS7FlOzTc^}JFDpHQ3PrxS9z<3Gu^pz z=+S)^cl<|?MC-9s+hpr3WlgW4#t3R?*mKv#@nyFQ{a4z=R;d$TV(_X}G=31V5OwWB0xblcsyl9Nt<^Gnh|&&|;guO|j@2C}&84a68YZXpw|E&KcvO-zCH zzo(D2{HU?o^`geH#E+9|t1`1J8tE;)W?(8|JzL zQ?ygiW)8;cq>eJHsh8Su^5nKjB};y0QqxI$dj7(h8D&LtSI$>Vcc2^l)ol$g%KVZS z)~lpHcQ9&Ky>8O3E{dSBl*mH6n{>KKb+Hpy9BE~m9OnK$U0TX&L=mqniydsEzYA&9 zuAPuh?OJlS9kzuWID7ep%kx!i=fs%{^Od}ygU(c|2Q8^DhOT}+6c5un?4o3^!3Y|Q zO05}rfaD(qnZX^8gJ@7ni4~x%apo{x({I~46htyL@}BLhgbpc8>E;7?6zPI9f(->7 z7W823;}{7`06jQ&R)(wJV1O61X{BaSgI)no03xlK&cy~3sj8PYa@uB>{S7IX?O)-H1EVAkpfF{e0opOXXpSnK(u z)Er1@KkC;oDGY2F>jm%!(*o@TA))}QDWrwF*UbL)9_u2Aqz>&7whM@b)s`l-!y34S zG(V~JhPd8fd;o8j$ryKJ@K(bB#bSzIZ-m>KT=qP;g|t}K`usth2=)d09a@_{6NDZx z%L}LLL5L3Qw-fqs=EO^ab~Lkf8<)sIi4EMi{&cFM3vxJVbN^aU^OcQn>ELV`#0H+# zL%)_(3~2=s3|Sf-Ag9>^B2-Lql4Od@fJp#ihF@_a0`AKk#CQ-jw*Tl>;~Vck^_LG; 
zXP<`yjY{)c;9r+U5G7d4OO~Vs1Z@jJ1}U8|h+;U1ueK(RgC`EqgdJb)dIBgNUloZs zN-TC>Jfo(%QN10~&U`PDhQHF0M%-;xO0Jdo(WS`E;lCEhL-QCf1dWy#-|5!GLu!#4 z_P}G;hRDUXQ0=R}a2<@#1#uGZ^lEDEayLcYKs0@n zjT)nHJ5w{^Ze>rig%k0BbjZa@b zUirzhf4V@xvp+(@07(m82lzC4nYoC;Lw^SV~aU03h_WNd0 zC6^^1%IU-j9I1Gk$WtK5R7{HB6a{lAS?wLOe;U5~fQ75oez%3Z&+mhr4fmh`=fe1NLEq^eeN+*-`p^sT=y2wqbs0 zTI#7iWY;K=y?w+!yl10Y6 z+l*3#rQKtdZsd1PBA?~~&DK4Q0D30~6E6@^7LcWp&wNdM8!Tz9`Jx?H3U_=|vg#d0 zi_7%fOGI8KG7sXa%(adv0L0pF1dl-A8t(WIh=ydWq;DI-IC~=d86{ni++C%`4jR0N z*@TW6nb(iRI!joO;xeIv(KLIOqKN97;KAF;FyK@0^w*PI;=osNS7d;hee2OO_a8vbY!SY>N-KvhJjbp=K$veKx|)h+VDT5TgZdsq(ZHREG3?3_(@e212PxHRb(m4k}XMgU4&R`zhRWJ z=d&XAXAp?#Ei~lu!%TG1+hY*)wmWb50R}2ogElGGhGMz4?1&=?B4?`oRQp)J!QN&G zl;5;&sAw;~mDo+xSfuEUtnSJn*n)q*iPhlWSS_zs!)Ko5ym%Yq0@Z1pvTnwabJt0x z$Q|U^RL_;L=D;R>@W{`tjh;n8-A<;-(fZX_7B$Ig#3dpsekz+m0Mf)VNM%r+w4nGT zhcCT|qdixBW$x_7({mSQ&&iSVaMZbWa=zV-_^c49ql^&vsL)pdZxODC&k0qSV=OAP zO25!Tk#z#amT1#oMMM-@52IHx#CJ%txUDoOQgld!is+0`wIvJqHb9jzpy}rr$7K)= z@z+5eOtmoG4x7?N3jQA8r5D^$7@>BiFnkw53WT!6*fdaFAx7fx%(?8K)S=Mv!2Bqi z7F3TT9rs16oA@%{jgUn9K8%jtbGJj;2N6`T2qA!?-{A3B<%UeWMawerplZ%n*j(j7 zaRdVuy@X4>H$bEVGwygI2L3g|_9DOzNGwV;Pb&#ZgtN_&NlL4;8eH{&<9?<_iZ$+- z`T#N$MKkqoR3m=^kQVOr#TihRSI?oyvpq|%}D94ek}=s!=#8$S=?K!4l65_1gP%z+;1hs@q| aCJIB&l;VD+aw@xgpgiu3IQfEA82&%%10+-c literal 0 HcmV?d00001 diff --git a/ultraled/data/collate_fn.py b/ultraled/data/collate_fn.py new file mode 100644 index 0000000..1fde0da --- /dev/null +++ b/ultraled/data/collate_fn.py @@ -0,0 +1,25 @@ +from torch.utils.data._utils.collate import collate, default_collate_fn_map +from copy import deepcopy +import torch + +DEFAUT_COLLATE_FN_MAP = default_collate_fn_map + +def starlight_collate_fn(): + collate_fn_map = deepcopy(DEFAUT_COLLATE_FN_MAP) + + def collate_tensor_fn(batch, *, collate_fn_map): + return torch.cat(batch, 0) + + def collate_list_fn(batch, *, collate_fn_map): + L = [] + for l in batch: + L.extend(l) + return L + + collate_fn_map[torch.Tensor] = collate_tensor_fn + collate_fn_map[list] = collate_list_fn + + def _collate_fn(batch): + return collate(batch, collate_fn_map=collate_fn_map) + + return _collate_fn \ No newline at end of file diff --git a/ultraled/data/commom_noise_util.py b/ultraled/data/commom_noise_util.py new file mode 100644 index 0000000..805d379 --- /dev/null +++ b/ultraled/data/commom_noise_util.py @@ -0,0 +1,316 @@ +from abc import ABC +import math +import torch +import numpy as np +from scipy import stats + +def _unpack_bayer(x): + _, h, w = x.shape + H = h*2 + W = w*2 + out = np.zeros((H, W)) + + out[0:H:2, 0:W:2] = x[0] + out[0:H:2, 1:W:2] = x[1] + out[1:H:2, 1:W:2] = x[2] + out[1:H:2, 0:W:2] = x[3] + return out + +def _pack_bayer(raw): + h, w = raw.shape + out = np.concatenate((raw[np.newaxis, 0:h:2, 0:w:2], + raw[np.newaxis, 0:h:2, 1:w:2], + raw[np.newaxis, 1:h:2, 1:w:2], + raw[np.newaxis, 1:h:2, 0:w:2]), axis=0) + return out + +def _unpack_bayer_torch(x): + _, h, w = x.shape + H = h*2 + W = w*2 + out = torch.zeros((H, W)) + + out[0:H:2, 0:W:2] = x[0] + out[0:H:2, 1:W:2] = x[1] + out[1:H:2, 1:W:2] = x[2] + out[1:H:2, 0:W:2] = x[3] + return out + +# def _pack_bayer_torch(raw): +# h, w = raw.shape +# out = torch.cat((raw[None, 0:h:2, 0:w:2], # R +# raw[None, 0:h:2, 1:w:2], # G1 +# raw[None, 1:h:2, 1:w:2], # B +# raw[None, 1:h:2, 0:w:2] # G2 +# ), dim=0) +# return out + +def _pack_bayer_torch(raw): + h, w = raw.size(-2), raw.size(-1) + out = torch.cat((raw[..., None, 0:h:2, 0:w:2], # R + raw[..., None, 0:h:2, 1:w:2], # G1 + raw[..., None, 1:h:2, 1:w:2], # B + raw[..., None, 1:h:2, 0:w:2] # 
G2 + ), dim=-3) + return out + +def _pack_batch_bayer_torch(raw): + _, h, w = raw.shape + out = torch.cat((raw[:, None, 0:h:2, 0:w:2], # R + raw[:, None, 0:h:2, 1:w:2], # G1 + raw[:, None, 1:h:2, 1:w:2], # B + raw[:, None, 1:h:2, 0:w:2] # G2 + ), dim=1) + return out + +def torch_shot_noise(x, k): + return torch.poisson(x / k) * k - x + +def torch_gaussian_noise(x, scale, loc=0): + return torch.randn_like(x) * scale + loc + # return torch.zeros_like(x).normal_(loc, scale) + +def torch_turkey_lambda_noise(x, scale, t_lambda=1.4): + def turkey_lambda_ppf(p, t_lambda): + # assert not torch.any(torch.tensor(t_lambda == 0.0)) + return 1 / t_lambda * (p ** t_lambda - (1 - p) ** t_lambda) + + epsilon = 1e-10 + U = torch.rand_like(x) * (1 - 2 * epsilon) + epsilon + Y = turkey_lambda_ppf(U, t_lambda + 1e-8) * scale + + return Y + +def torch_quant_noise(x, q): + return (torch.rand_like(x) - 0.5) * q + +# def torch_row_noise(x, scale, loc=0): +# _, H, W = x.shape +# noise = torch.zeros((H * 2, 1), device=x.device).normal_(loc, scale).repeat((1, W * 2)) +# return _pack_bayer_torch(noise) + +def torch_row_noise(x, scale, loc=0): + if x.dim() == 4: + B, _, H, W = x.shape + noise = (torch.randn((B, H * 2, 1), device=x.device) * scale + loc).repeat((1, 1, W * 2)) + elif x.dim() == 5: + B, T, _, H, W = x.shape + noise = (torch.randn((B, T, H * 2, 1), device=x.device) * scale + loc).repeat((1, 1, 1, W * 2)) + elif x.dim() == 3: + _, H, W = x.shape + noise = torch.zeros((H * 2, 1), device=x.device).normal_(loc, scale).repeat((1, W * 2)) + else: + raise NotImplementedError() + return _pack_bayer_torch(noise) + +def torch_batch_row_noise(x, scale, loc=0): + B, _, H, W = x.shape + noise = (torch.randn((B, H * 2, 1), device=x.device) * scale + loc).repeat((1, 1, W * 2)) + return _pack_batch_bayer_torch(noise) + +def numpy_shot_noise(x, k): + # print(x, k) + return np.random.poisson(x / k).astype(np.float32) * k - x + +def numpy_gaussian_noise(x, scale): + return stats.norm.rvs(scale=scale, size=x.shape).astype(np.float32) + +def numpy_turkey_lambda_noise(x, scale, t_lambda=1.4): + return stats.tukeylambda.rvs(t_lambda, scale=scale, size=[*x.shape]).astype(np.float32) + +def numpy_row_noise(x, scale): + _, H, W = x.shape + noise = np.random.randn(H * 2, 1).astype(np.float32) * scale + noise = np.repeat(noise, W * 2, 1) + return _pack_bayer(noise) + +def numpy_quant_noise(x, q): + return np.random.uniform(low=-0.5*q, high=0.5*q, size=x.shape) + +class Engine(ABC): + @staticmethod + def uniform(min, max, shape=None): + pass + + @staticmethod + def randint(min, max, shape=None): + pass + + @staticmethod + def randn(shape=None): + pass + + @staticmethod + def log(x): + pass + + @staticmethod + def exp(x): + pass + + @staticmethod + def to_engine_type(x): + pass + + @staticmethod + def shot_noise(x, k): + pass + + @staticmethod + def gaussian_noise(x, scale): + pass + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + pass + + @staticmethod + def row_noise(x, scale): + pass + + @staticmethod + def quant_noise(x, q): + pass + +class NumpyEngine(Engine): + @staticmethod + def uniform(min, max, shape=None): + if shape == None: + return np.random.uniform(min, max) + return np.random.uniform(min, max, size=shape) + + @staticmethod + def randint(min, max, shape=None): + if shape == None: + return np.random.randint(min, max) + return np.random.randint(min, max, size=shape) + + @staticmethod + def randn(shape=None): + if shape == None: + return np.random.randn() + return np.random.randn(shape) + + @staticmethod + 
def log(x): + return np.log(x) + + @staticmethod + def exp(x): + return np.exp(x) + + @staticmethod + def to_engine_type(x): + return np.array(x) + + @staticmethod + def shot_noise(x, k): + return numpy_shot_noise(x, k) + + @staticmethod + def gaussian_noise(x, scale): + return numpy_gaussian_noise(x, scale) + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + return numpy_turkey_lambda_noise(x, scale, t_lambda) + + @staticmethod + def row_noise(x, scale): + return numpy_row_noise(x, scale) + + @staticmethod + def quant_noise(x, q): + return numpy_quant_noise(x, q) + +class TorchEngine(Engine): + @staticmethod + def uniform(min, max, shape=1, device='cpu'): + if shape != 1: + return torch.rand((shape,), device=device) * (max - min) + min + return torch.rand((shape,)).item() * (max - min) + min + + @staticmethod + def randint(min, max, shape=1, device='cpu'): + if shape != 1: + return torch.randint(min, max, (shape,), device=device) + return torch.randint(min, max, (shape,)).item() + + @staticmethod + def randn(shape=1, device='cpu'): + if shape != 1: + return torch.randn((shape,), device=device) + return torch.randn((shape,)).item() + + @staticmethod + def log(x): + return math.log(x) + + @staticmethod + def exp(x): + return math.exp(x) + + @staticmethod + def to_engine_type(x): + return torch.tensor(x) + + @staticmethod + def shot_noise(x, k): + return torch_shot_noise(x, k) + + @staticmethod + def gaussian_noise(x, scale): + return torch_gaussian_noise(x, scale) + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + return torch_turkey_lambda_noise(x, scale, t_lambda) + + @staticmethod + def row_noise(x, scale): + return torch_row_noise(x, scale) + + @staticmethod + def quant_noise(x, q): + return torch_quant_noise(x, q) + + +class TorchBatchEngine(TorchEngine): + def __init__(self, use_hflip=True, use_rot=True) -> None: + super().__init__() + self.use_hflip = use_hflip + self.use_rot = use_rot + + @staticmethod + def uniform(min, max, shape=None, device='cpu'): + if shape is not None: + return torch.rand((shape,), device=device) * (max - min) + min + return torch.rand((1,)).item() * (max - min) + min + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + return torch_turkey_lambda_noise(x, scale, t_lambda) + + @staticmethod + def log(x): + return torch.log(x) + + @staticmethod + def exp(x): + return torch.exp(x) + + @staticmethod + def row_noise(x, scale): + return torch_batch_row_noise(x, scale) + + def augment(self, *datas): + hflip = self.use_hflip and self.randint(0, 2) == 1 + vflip = self.use_rot and self.randint(0, 2) == 1 + rot90 = self.use_rot and self.randint(0, 2) == 1 + if hflip: + datas = [torch.flip(data, (3,)) for data in datas] + if vflip: + datas = [torch.flip(data, (2,)) for data in datas] + if rot90: + datas = [torch.permute(data, (0, 1, 3, 2)) for data in datas] + return datas \ No newline at end of file diff --git a/ultraled/data/data_sampler.py b/ultraled/data/data_sampler.py new file mode 100644 index 0000000..575452d --- /dev/null +++ b/ultraled/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. 
+ num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/ultraled/data/hdr_paired_noiseraw_dataset.py b/ultraled/data/hdr_paired_noiseraw_dataset.py new file mode 100644 index 0000000..9b58c36 --- /dev/null +++ b/ultraled/data/hdr_paired_noiseraw_dataset.py @@ -0,0 +1,520 @@ +import re +import math +from torch.utils import data as data +import numpy as np +import torch +from os import path as osp +from tqdm import tqdm +from ultraled.data.noise_util_rawhdr import NoiseGenerator + + + +from ultraled.utils.registry import DATASET_REGISTRY + + +import torch +import math + +import random +from ultraled.data.part_enhance import * + + +@DATASET_REGISTRY.register() +class EstimatorHDRRAWDataset(data.Dataset): + def __init__(self, opt) -> None: + super().__init__() + self.opt = opt + # print(self.opt) + self.engine = self.opt.get('engine', 'torch') + + self.root_folder = opt['dataroot'] + self.postfix = opt.get('postfix', None) + self.which_meta = opt.get('which_meta', 'gt') + self.zero_clip = 0 if opt.get('zero_clip', True) else None + self.ratio_range = opt.get('ratio_range', (100, 300)) + self.use_patches = opt.get('use_patches', False) + self.load_in_mem = opt.get('load_in_mem', False) + if self.use_patches: + + self.patch_id_max = opt.get('patch_id_max', 8) + self.patch_tplt = opt.get('patch_tplt', '_s{:03}') + + assert self.postfix == 'npz' + + + self.lq_paths, self.gt_paths = [], [] + with open(opt['data_pair_list'], 'r') as data_pair_list: + pairs = data_pair_list.readlines() + for pair in pairs: + lq, gt = pair.split(' ')[:2] + gt = gt.rstrip('\n') + + if not self.use_patches: + self.lq_paths.append(osp.join(self.root_folder, lq)) + self.gt_paths.append(osp.join(self.root_folder, gt)) + + else: + for i in range(1, 1 + self.patch_id_max): + self.lq_paths.append(osp.join(self.root_folder, self.insert_patch_id(lq, i, self.patch_tplt))) + self.gt_paths.append(osp.join(self.root_folder, self.insert_patch_id(gt, i, self.patch_tplt))) + + if self.load_in_mem: + self.lqs = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.lq_paths, desc='load lq metas in mem...') + } + self.gts = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.gt_paths, desc='load gt metas in mem...') + } + + @staticmethod + def insert_patch_id(path, patch_id, tplt='_s{:03}'): + exts = path.split('.') + base = exts.pop(0) + while exts[0] != 'ARW': + base += '.' 
+ exts.pop(0) + base = base + tplt.format(patch_id) + return base + '.' + '.'.join(exts) + + @staticmethod + def depack_meta(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + ## using rawpy + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return (im - black_level) / (white_level - black_level), \ + wb, ccm + + @staticmethod + def depack_meta_gt(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + ## using rawpy + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return im, \ + wb, ccm + + + def randint(self, *range): + if self.engine == 'torch': + return torch.randint(*range, size=(1,)).item() + else: + return np.random.randint(*range) + + def flip(self, x, dim): + if self.engine == 'torch': + return torch.flip(x, (dim,)) + else: + return np.flip(x, dim) + + def transpose(self, x): + if self.engine == 'torch': + return torch.permute(x, (0, 2, 1)) + else: + return np.transpose(x, (0, 2, 1)) + + def __getitem__(self, index): + + def sum_img_and_noise(img, noises): + for noise in noises: + img += noise + return img + + + lq_path = self.lq_paths[index] + gt_path = self.gt_paths[index] + ratio = self.randint(*self.ratio_range) + + if self.postfix is not None: + lq_path = '.'.join([lq_path, self.postfix]) + gt_path = '.'.join([gt_path, self.postfix]) + + + if not self.load_in_mem: + lq_im, lq_wb, lq_ccm = self.depack_meta(lq_path, self.postfix) + gt_im, gt_wb, gt_ccm = self.depack_meta_gt(gt_path, self.postfix) + else: + lq_im, lq_wb, lq_ccm = self.lqs[lq_path] + gt_im, gt_wb, gt_ccm = self.gts[gt_path] + + + ## crop + if self.opt['crop_size'] is not None: + _, H, W = lq_im.shape + crop_size = self.opt['crop_size'] + assert crop_size <= H and crop_size <= W + if self.opt['phase'] == 'train': + h_start = torch.randint(0, H - crop_size, (1,)).item() + w_start = torch.randint(0, W - crop_size, (1,)).item() + else: + # center crop + h_start = (H - crop_size) // 2 + w_start = (W - crop_size) // 2 + lq_im_patch = lq_im[:, h_start:h_start+crop_size, 
w_start:w_start+crop_size] + gt_im_patch = gt_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + else: + lq_im_patch = lq_im + gt_im_patch = gt_im + ## flip + rotate + if self.opt['phase'] == 'train': + hflip = self.opt['use_hflip'] and torch.rand((1,)).item() < 0.5 + vflip = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + rot90 = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + if hflip: + lq_im_patch = torch.flip(lq_im_patch, (2,)) + gt_im_patch = torch.flip(gt_im_patch, (2,)) + if vflip: + lq_im_patch = torch.flip(lq_im_patch, (1,)) + gt_im_patch = torch.flip(gt_im_patch, (1,)) + if rot90: + lq_im_patch = torch.permute(lq_im_patch, (0, 2, 1)) + gt_im_patch = torch.permute(gt_im_patch, (0, 2, 1)) + + if self.opt.get('ratio_aug') is not None: + ratio_range = self.opt['ratio_aug'] + rand_ratio = torch.rand((1,)).item() * (ratio_range[1] - ratio_range[0]) + ratio_range[0] + + gt_im_patch = gt_im_patch / ratio * rand_ratio + ratio = rand_ratio + + + + im_ratio = gt_im_patch.mean() // lq_im_patch.mean() + lq_im_patch = gt_im_patch / im_ratio + ratio_all = im_ratio * ratio + _, H, W = gt_im_patch.shape + + Highlight_Generator = EstimatorHighlightGenerator(ratio_all) + mask, addmap, lq_im_patch1 = Highlight_Generator.generate_highlight(lq_im_patch) + mask = mask.unsqueeze(0) + + lq_im_patch = lq_im_patch.to(gt_im_patch.device) + lq_im_patch = lq_im_patch * addmap + + map_ratio = (lq_im_patch * ratio_all ) / (gt_im_patch * ratio + 1e-8) + ratiomap_output = torch.mean(map_ratio, dim=0, keepdim=True).unsqueeze(0) * addmap + + lq_ev0_non_zero = lq_im_patch * im_ratio + exposures = [ + torch.clip(lq_ev0_non_zero, min=0, max=1), + torch.clip(lq_ev0_non_zero / 200, min=0, max=1), + torch.clip(lq_ev0_non_zero / 120, min=0, max=1), + torch.clip(lq_ev0_non_zero / 60, min=0, max=1), + torch.clip(lq_ev0_non_zero / 32, min=0, max=1), + torch.clip(lq_ev0_non_zero / 16, min=0, max=1), + torch.clip(lq_ev0_non_zero / 8, min=0, max=1), + torch.clip(lq_ev0_non_zero / 4, min=0, max=1), + torch.clip(lq_ev0_non_zero / 2, min=0, max=1), + ] + + exposures_gt = torch.stack(exposures) + lq_nonoise = lq_im_patch * im_ratio + gt_im_patch = torch.clip(gt_im_patch, min=0) + gt_im_patch = ratiomap_output.squeeze(0) + + + + + + + return { + 'lq_clean':lq_nonoise, + 'gt': exposures_gt, + 'ratio': torch.tensor(ratio), + 'ratio1': ratio_all, + 'wb': gt_wb if self.which_meta == 'gt' else lq_wb, + 'ccm': gt_ccm if self.which_meta == 'gt' else lq_ccm, + 'lq_path': lq_path, + 'gt_path': gt_path, + 'intact': True + } + + def __len__(self): + return len(self.lq_paths) + + + +@DATASET_REGISTRY.register() +class DenoiserHDRRAWDataset(data.Dataset): + def __init__(self, opt) -> None: + super().__init__() + self.opt = opt + self.engine = self.opt.get('engine', 'torch') + + self.root_folder = opt['dataroot'] + self.postfix = opt.get('postfix', None) + self.which_meta = opt.get('which_meta', 'gt') + self.zero_clip = 0 if opt.get('zero_clip', True) else None + self.ratio_range = opt.get('ratio_range', (100, 300)) + self.use_patches = opt.get('use_patches', False) + self.load_in_mem = opt.get('load_in_mem', False) + if self.use_patches: + self.patch_id_max = opt.get('patch_id_max', 8) + self.patch_tplt = opt.get('patch_tplt', '_s{:03}') + + assert self.postfix == 'npz' + + + self.lq_paths, self.gt_paths = [], [] + with open(opt['data_pair_list'], 'r') as data_pair_list: + pairs = data_pair_list.readlines() + for pair in pairs: + lq, gt = pair.split(' ')[:2] + gt = gt.rstrip('\n') + + if not self.use_patches: + 
self.lq_paths.append(osp.join(self.root_folder, lq)) + self.gt_paths.append(osp.join(self.root_folder, gt)) + + else: + for i in range(1, 1 + self.patch_id_max): + self.lq_paths.append(osp.join(self.root_folder, self.insert_patch_id(lq, i, self.patch_tplt))) + self.gt_paths.append(osp.join(self.root_folder, self.insert_patch_id(gt, i, self.patch_tplt))) + + if self.load_in_mem: + self.lqs = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.lq_paths, desc='load lq metas in mem...') + } + self.gts = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.gt_paths, desc='load gt metas in mem...') + } + + @staticmethod + def insert_patch_id(path, patch_id, tplt='_s{:03}'): + exts = path.split('.') + base = exts.pop(0) + while exts[0] != 'ARW': + base += '.' + exts.pop(0) + base = base + tplt.format(patch_id) + return base + '.' + '.'.join(exts) + + @staticmethod + def depack_meta(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return (im - black_level) / (white_level - black_level), \ + wb, ccm + + @staticmethod + def depack_meta_gt(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + ## using rawpy + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return im, \ + wb, ccm + + + def randint(self, *range): + if self.engine == 'torch': + return torch.randint(*range, size=(1,)).item() + else: + return np.random.randint(*range) + + def flip(self, x, dim): + if self.engine == 'torch': + return torch.flip(x, (dim,)) + else: + return np.flip(x, dim) + + def transpose(self, x): + if self.engine == 'torch': + return torch.permute(x, (0, 2, 1)) + else: + return np.transpose(x, (0, 2, 1)) + + def __getitem__(self, index): + + def sum_img_and_noise(img, noises): + for noise in noises: + img += noise + return img + + + lq_path = 
self.lq_paths[index] + gt_path = self.gt_paths[index] + ratio = self.randint(*self.ratio_range) + + if self.postfix is not None: + lq_path = '.'.join([lq_path, self.postfix]) + gt_path = '.'.join([gt_path, self.postfix]) + + if not self.load_in_mem: + lq_im, lq_wb, lq_ccm = self.depack_meta(lq_path, self.postfix) + gt_im, gt_wb, gt_ccm = self.depack_meta_gt(gt_path, self.postfix) + else: + lq_im, lq_wb, lq_ccm = self.lqs[lq_path] + gt_im, gt_wb, gt_ccm = self.gts[gt_path] + + ## crop + if self.opt['crop_size'] is not None: + _, H, W = lq_im.shape + crop_size = self.opt['crop_size'] + assert crop_size <= H and crop_size <= W + if self.opt['phase'] == 'train': + h_start = torch.randint(0, H - crop_size, (1,)).item() + w_start = torch.randint(0, W - crop_size, (1,)).item() + else: + # center crop + h_start = (H - crop_size) // 2 + w_start = (W - crop_size) // 2 + lq_im_patch = lq_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + gt_im_patch = gt_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + else: + lq_im_patch = lq_im + gt_im_patch = gt_im + ## flip + rotate + if self.opt['phase'] == 'train': + hflip = self.opt['use_hflip'] and torch.rand((1,)).item() < 0.5 + vflip = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + rot90 = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + if hflip: + lq_im_patch = torch.flip(lq_im_patch, (2,)) + gt_im_patch = torch.flip(gt_im_patch, (2,)) + if vflip: + lq_im_patch = torch.flip(lq_im_patch, (1,)) + gt_im_patch = torch.flip(gt_im_patch, (1,)) + if rot90: + lq_im_patch = torch.permute(lq_im_patch, (0, 2, 1)) + gt_im_patch = torch.permute(gt_im_patch, (0, 2, 1)) + + if self.opt.get('ratio_aug') is not None: + ratio_range = self.opt['ratio_aug'] + rand_ratio = torch.rand((1,)).item() * (ratio_range[1] - ratio_range[0]) + ratio_range[0] + + gt_im_patch = gt_im_patch / ratio * rand_ratio + ratio = rand_ratio + + + + im_ratio = gt_im_patch.mean() // lq_im_patch.mean() + lq_im_patch = gt_im_patch / im_ratio + ratio_all = im_ratio * ratio + _, H, W = gt_im_patch.shape + + + Highlight_Generator = DenoiserHighlightGenerator(ratio_all) + mask, addmap, lq_im_patch1 = Highlight_Generator.generate_highlight(lq_im_patch) + mask = mask.unsqueeze(0) + + lq_im_patch = lq_im_patch.to(gt_im_patch.device) + lq_im_patch = lq_im_patch * addmap + + map_ratio = (lq_im_patch * ratio_all ) / (gt_im_patch * ratio + 1e-8) + ratiomap_output = torch.mean(map_ratio, dim=0, keepdim=True).unsqueeze(0) * addmap + + + + lq_ev0_non_zero = lq_im_patch * im_ratio + exposures = [ + torch.clip(lq_ev0_non_zero, min=0, max=1), + torch.clip(lq_ev0_non_zero / 200, min=0, max=1), + torch.clip(lq_ev0_non_zero / 120, min=0, max=1), + torch.clip(lq_ev0_non_zero / 60, min=0, max=1), + torch.clip(lq_ev0_non_zero / 32, min=0, max=1), + torch.clip(lq_ev0_non_zero / 16, min=0, max=1), + torch.clip(lq_ev0_non_zero / 8, min=0, max=1), + torch.clip(lq_ev0_non_zero / 4, min=0, max=1), + torch.clip(lq_ev0_non_zero / 2, min=0, max=1), + ] + + exposures_gt = torch.stack(exposures) + lq_nonoise = lq_im_patch * im_ratio + gt_im_patch = torch.clip(gt_im_patch, min=0) + gt_im_patch = ratiomap_output.squeeze(0) + + + + + + + return { + 'lq_clean':lq_nonoise, + 'gt': exposures_gt, + 'ratio': torch.tensor(ratio), + 'ratio1': ratio_all, + 'wb': gt_wb if self.which_meta == 'gt' else lq_wb, + 'ccm': gt_ccm if self.which_meta == 'gt' else lq_ccm, + 'lq_path': lq_path, + 'gt_path': gt_path, + 'intact': True + } + + def __len__(self): + return len(self.lq_paths) \ No newline at end of file 
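Usage sketch (illustrative, not a file added by this patch): a minimal way to drive the dataset classes above, assuming packed-raw .npz metas under dataroot and a data_pair_list text file whose lines hold "<lq_path> <gt_path>" pairs. The option keys follow the opt lookups in __init__/__getitem__; opt, train_set and the paths are placeholder names.

from torch.utils.data import DataLoader
from ultraled.data.hdr_paired_noiseraw_dataset import DenoiserHDRRAWDataset

opt = {
    'phase': 'train',
    'dataroot': '/path/to/packed_raw',              # placeholder root for the .npz metas
    'data_pair_list': '/path/to/train_pairs.txt',   # placeholder; one "<lq> <gt>" pair per line
    'postfix': 'npz',
    'which_meta': 'gt',
    'crop_size': 512,
    'use_hflip': True,
    'use_rot': True,
    'ratio_range': (100, 300),
    'use_patches': False,
    'load_in_mem': False,
}

train_set = DenoiserHDRRAWDataset(opt)
loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=2)

batch = next(iter(loader))
# 'lq_clean' is the clean, highlight-boosted raw at the base exposure,
# 'gt' stacks the nine clipped exposure levels, and 'ratio'/'ratio1' carry
# the sampled and overall amplification factors used to build the pair.
print(batch['lq_clean'].shape, batch['gt'].shape, batch['ratio'])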
diff --git a/ultraled/data/hdr_util.py b/ultraled/data/hdr_util.py new file mode 100644 index 0000000..72758ad --- /dev/null +++ b/ultraled/data/hdr_util.py @@ -0,0 +1,234 @@ +import math +import torch +from torch import nn +from torch.nn.functional import interpolate, pad, conv2d + +class ClipBase(nn.Module): + def __init__(self, clip=False) -> None: + super().__init__() + self._clip_func = torch.clamp if clip else diff_clamp + + def do_clip(self): + self._clip_func = torch.clamp + + def dont_clip(self): + self._clip_func = diff_clamp + + def forward(self, x): + return x + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def normalize_kernel2d(input): + norm = input.abs().sum(dim=-1).sum(dim=-1) + return input / (norm[..., None, None]) + + +def filter2d(input, kernel, border_type: str = 'reflect'): + # prepare kernel + c = input.shape[-3] + shape = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape = _compute_padding([height, width]) + if input.dim() == 5: + padding_shape += [0, 0] + input = pad(input, padding_shape, mode=border_type) + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + out = output.view(*shape) + return out + +def get_laplacian_kernel2d(kernel_size, *, device = None, dtype = torch.float32): + ky, kx = kernel_size + kernel = torch.ones((ky, kx), device=device, dtype=dtype) + mid_x = kx // 2 + mid_y = ky // 2 + kernel[mid_y, mid_x] = 1 - kernel.sum() + return kernel + +def laplacian(input, kernel_size, border_type: str = 'reflect', normalized: bool = True): + kernel = get_laplacian_kernel2d(kernel_size, device=input.device, dtype=input.dtype)[None, ...] 
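+    # normalize_kernel2d divides the kernel by the sum of its absolute values,
+    # so the normalized Laplacian response stays on a comparable scale for any kernel_size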
+ + if normalized: + kernel = normalize_kernel2d(kernel) + + return filter2d(input, kernel, border_type) + + +def rgb_to_grayscale(rgb): + r, g, b = rgb.unbind(dim=-3) + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(rgb.dtype) + l_img = l_img.unsqueeze(dim=-3) + return l_img + +def get_pyramid_gaussian_kernel(): + return ( + torch.tensor( + [ + [ + [1.0, 4.0, 6.0, 4.0, 1.0], + [4.0, 16.0, 24.0, 16.0, 4.0], + [6.0, 24.0, 36.0, 24.0, 6.0], + [4.0, 16.0, 24.0, 16.0, 4.0], + [1.0, 4.0, 6.0, 4.0, 1.0], + ] + ] + ) + / 256.0 + ) + +def pyrdown(input, border_type: str = 'reflect', align_corners: bool = False, factor: float = 2.0): + kernel = get_pyramid_gaussian_kernel() + channel, height, width = input.shape[-3:] + # blur image + x_blur = filter2d(input, kernel, border_type) + + shape = [int(float(height) / factor), int(float(width) // factor)] + mode = 'bilinear' + if input.dim() == 5: + mode = 'trilinear' + shape = [channel] + shape + + out = interpolate( + x_blur, + size=shape, + mode=mode, + align_corners=align_corners, + ) + return out + +def pyrup(input, shape, border_type: str = 'reflect', align_corners: bool = False): + kernel = get_pyramid_gaussian_kernel() + # upsample tensor + mode = 'bilinear' + if input.dim() == 5: + mode = 'trilinear' + else: + shape = shape[-2:] + x_up = interpolate(input, size=shape, mode=mode, align_corners=align_corners) + + # blurs upsampled tensor + x_blur = filter2d(x_up, kernel, border_type) + return x_blur + +def build_pyramid(input, max_level: int, border_type: str = 'reflect', align_corners: bool = False): + # create empty list and append the original image + pyramid = [] + pyramid.append(input) + + # iterate and downsample + for _ in range(max_level - 1): + img_curr = pyramid[-1] + img_down = pyrdown(img_curr, border_type, align_corners) + pyramid.append(img_down) + + return pyramid + +def build_laplacian_pyramid(input, max_level: int, border_type: str = 'reflect', align_corners: bool = False): + # create gaussian pyramid + gaussian_pyramid = build_pyramid(input, max_level, border_type, align_corners) + laplacian_pyramid = [] + + # iterate and compute difference of adjacent layers in a gaussian pyramid + for i in range(max_level - 1): + img_expand = pyrup(gaussian_pyramid[i + 1], gaussian_pyramid[i].shape[-3:], border_type, align_corners) + laplacian = gaussian_pyramid[i] - img_expand + laplacian_pyramid.append(laplacian) + laplacian_pyramid.append(gaussian_pyramid[-1]) + return laplacian_pyramid + +def diff_clamp(x, _min, _max, k=1e-3): + x = torch.minimum(x - _max, (x - _max) * k) + _max + x = torch.maximum(x - _min, (x - _min) * k) + _min + return x + +def pyramid_collapse(pyramid, depth): + for i in range(depth, 0, -1): + pyramid[i-1] += pyrup(pyramid[i], pyramid[i-1].shape[-3:]) + return pyramid[0] + + +class BlendMertens(nn.Module): + def __init__(self, contrast_weight=1.0, saturation_weight=1.0, exposure_weight=1.0, clip=False) -> None: + super().__init__() + self._clip_func = torch.clamp if clip else diff_clamp + self._contrast_weight = contrast_weight + self._saturation_weight = saturation_weight + self._exposure_weight = exposure_weight + + def do_clip(self): + self._clip_func = torch.clamp + + def dont_clip(self): + self._clip_func = diff_clamp + + def get_weight(self, x): + # contrast + gray_x = rgb_to_grayscale(x) + laplacian_x = laplacian(gray_x, (5, 5)) + c_weight = torch.abs(laplacian_x) + c_weight = c_weight ** self._contrast_weight + + # saturation + s_weight = torch.std(x, -3, keepdim=True) + s_weight = s_weight ** 
self._saturation_weight + + # exposure + sig = 0.2 + e_weight = torch.exp(-torch.pow(x - 0.5, 2) / (2 * sig * sig)) + r_w, g_w, b_w = torch.chunk(e_weight, 3, -3) + e_weight = r_w * g_w * b_w + e_weight = e_weight ** self._exposure_weight + + return c_weight * s_weight * e_weight + 1e-12 + + def forward(self, *data): + result = 0 + data = torch.stack(data) + weights = self.get_weight(data) + weight_sum = torch.sum(weights, 0, keepdim=True) + weights = weights / weight_sum + + pyramid_depth = min(int(math.log2(512)), int(math.log2(min(data[0].shape[-2:])))) + # pyramids + lps = build_laplacian_pyramid(data, pyramid_depth) + gps = build_pyramid(weights, pyramid_depth) + + # combine pyramids with weights + result_ps = [] + for i in range(pyramid_depth): + r_i = torch.sum(lps[i] * gps[i], 0) + result_ps.append(r_i) + result = self._clip_func(pyramid_collapse(result_ps, pyramid_depth-1), 0, 1) + return result \ No newline at end of file diff --git a/ultraled/data/noise_util_rawhdr.py b/ultraled/data/noise_util_rawhdr.py new file mode 100644 index 0000000..9db67d6 --- /dev/null +++ b/ultraled/data/noise_util_rawhdr.py @@ -0,0 +1,254 @@ +import torch +from torch import nn +import random +from ultraled.data.commom_noise_util import * + + + +class NoiseGenerator(nn.Module): + def __init__(self, camera_params, noise_type, *, engine='torch') -> None: + super().__init__() + self.camera_params = camera_params + self.cameras = list(camera_params.keys()) + print('Current Using Cameras: ', self.cameras) + + if engine == 'numpy': + self.engine = NumpyEngine() + else: + self.engine = TorchEngine() + + self.noise_type = noise_type.lower() + self.read_type = 'TurkeyLambda' if 't' in self.noise_type else \ + ('Gaussian' if 'g' in self.noise_type else None) + + @property + def sample_K(self): + index = self.engine.randint(0, len(self.camera_params)) + self.current_camera = self.cameras[index] + self.current_camera_params = self.camera_params[self.current_camera] + self.current_k_range = [ + self.camera_params[self.current_camera]['Kmin'], + self.camera_params[self.current_camera]['Kmax'] + ] + log_K_max = self.engine.log(self.current_camera_params['Kmax']) + log_K_min = self.engine.log(self.current_camera_params['Kmin']) + log_K = self.engine.uniform(log_K_min, log_K_max) + self.log_K = log_K + return self.engine.exp(log_K) + + + @property + def sample_read_param(self): + slope = self.current_camera_params[self.read_type]['slope'] + bias = self.current_camera_params[self.read_type]['bias'] + sigma = self.current_camera_params[self.read_type]['sigma'] + mu = self.log_K * slope + bias + sample = self.engine.randn() * sigma + mu + return self.engine.exp(sample) + + @property + def sample_turkey_lambda(self): + if self.read_type != 'TurkeyLambda': + return None + index = self.engine.randint(0, len(self.current_camera_params[self.read_type]['lambda'])) + return self.current_camera_params[self.read_type]['lambda'][index] + + @property + def sample_row_param(self): + slope = self.current_camera_params['Row']['slope'] + bias = self.current_camera_params['Row']['bias'] + sigma = self.current_camera_params['Row']['sigma'] + mu = self.log_K * slope + bias + sample = self.engine.randn() * sigma + mu + return self.engine.exp(sample) + + + + @property + def sample_color_bias(self): + count = len(self.current_camera_params['ColorBias']) + i_range = (self.current_k_range[1] - self.current_k_range[0]) / count + index = int((self.engine.exp(self.log_K) - self.current_k_range[0]) // i_range) + index = max(min(index, 
len(self.current_camera_params['ColorBias']) - 1), 0) + color_bias = self.current_camera_params['ColorBias'][index] + return self.engine.to_engine_type(color_bias).reshape(4, 1, 1) + + + + @torch.no_grad() + # def forward(self, img): + def forward(self, img, *, K=None): + + + if K is not None: + self.sample_K = K + else: + K = self.sample_K + # K = self.sample_K + + noise1 = [] + # possion noise + if 'p' in self.noise_type: + shot_noise = self.engine.shot_noise(img, K) + noise1.append(shot_noise) + # read noise + if 'g' in self.noise_type: + read_noise = self.engine.gaussian_noise(img, self.sample_read_param) + noise1.append(read_noise) + elif 't' in self.noise_type: + read_noise = self.engine.turkey_lambda_noise(img, self.sample_read_param, self.sample_turkey_lambda) + noise1.append(read_noise) + # row noise + if 'r' in self.noise_type: + row_noise = self.engine.row_noise(img, self.sample_row_param) + noise1.append(row_noise) + # quant noise + if 'q' in self.noise_type: + quant_noise = self.engine.quant_noise(img, 1) + noise1.append(quant_noise) + if 'c' in self.noise_type: + noise1.append(self.sample_color_bias.to(img.device)) + + + return img, noise1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + +### Support multiple cameras + +# class NoiseGenerator(nn.Module): +# def __init__(self, camera_params, noise_type, *, engine='torch') -> None: +# super().__init__() +# self.camera_params = camera_params +# self.cameras = list(camera_params.keys()) +# print('Current Using Cameras: ', self.cameras) + +# if engine == 'numpy': +# self.engine = NumpyEngine() +# else: +# self.engine = TorchEngine() + +# self.noise_type = noise_type.lower() +# self.read_type = 'TurkeyLambda' if 't' in self.noise_type else \ +# ('Gaussian' if 'g' in self.noise_type else None) + +# @property +# def sample_K(self): +# index = self.engine.randint(0, len(self.camera_params)) +# self.current_camera = self.cameras[index] +# self.current_camera_params = self.camera_params[self.current_camera] +# self.current_k_range = [ +# self.camera_params[self.current_camera]['Kmin'], +# self.camera_params[self.current_camera]['Kmax'] +# ] +# log_K_max = self.engine.log(self.current_camera_params['Kmax']) +# log_K_min = self.engine.log(self.current_camera_params['Kmin']) +# log_K = self.engine.uniform(log_K_min, log_K_max) +# self.log_K = log_K +# return self.engine.exp(log_K) + +# @sample_K.setter +# def sample_K(self, K): +# assert len(self.camera_params) == 1 +# index = 0 +# self.current_camera = self.cameras[index] +# self.current_camera_params = self.camera_params[self.current_camera] +# self.current_k_range = [ +# self.camera_params[self.current_camera]['Kmin'], +# self.camera_params[self.current_camera]['Kmax'] +# ] +# self.log_K = self.engine.log(K) + +# @property +# def sample_read_param(self): +# slope = self.current_camera_params[self.read_type]['slope'] +# bias = self.current_camera_params[self.read_type]['bias'] +# sigma = self.current_camera_params[self.read_type]['sigma'] +# mu = self.log_K * slope + bias +# sample = self.engine.randn() * sigma + mu +# return self.engine.exp(sample) + +# @property +# def sample_turkey_lambda(self): +# if self.read_type != 'TurkeyLambda': +# return None +# index = self.engine.randint(0, len(self.current_camera_params[self.read_type]['lambda'])) +# return self.current_camera_params[self.read_type]['lambda'][index] + +# @property +# def sample_row_param(self): +# slope = self.current_camera_params['Row']['slope'] +# bias = self.current_camera_params['Row']['bias'] +# sigma = 
self.current_camera_params['Row']['sigma'] +# mu = self.log_K * slope + bias +# sample = self.engine.randn() * sigma + mu +# return self.engine.exp(sample) + +# @property +# def sample_color_bias(self): +# count = len(self.current_camera_params['ColorBias']) +# i_range = (self.current_k_range[1] - self.current_k_range[0]) / count +# index = int((self.engine.exp(self.log_K) - self.current_k_range[0]) // i_range) +# index = max(min(index, len(self.current_camera_params['ColorBias']) - 1), 0) +# color_bias = self.current_camera_params['ColorBias'][index] +# return self.engine.to_engine_type(color_bias).reshape(4, 1, 1) + +# @torch.no_grad() +# # def forward(self, img): +# def forward(self, img, *, K=None): +# if K is not None: +# self.sample_K = K +# else: +# K = self.sample_K +# # K = self.sample_K + +# noise1 = [] +# # possion noise +# if 'p' in self.noise_type: +# shot_noise = self.engine.shot_noise(img, K) +# noise1.append(shot_noise) +# # read noise +# if 'g' in self.noise_type: +# read_noise = self.engine.gaussian_noise(img, self.sample_read_param) +# noise1.append(read_noise) +# elif 't' in self.noise_type: +# read_noise = self.engine.turkey_lambda_noise(img, self.sample_read_param, self.sample_turkey_lambda) +# noise1.append(read_noise) +# # row noise +# if 'r' in self.noise_type: +# row_noise = self.engine.row_noise(img, self.sample_row_param) +# noise1.append(row_noise) +# # quant noise +# if 'q' in self.noise_type: +# quant_noise = self.engine.quant_noise(img, 1) +# noise1.append(quant_noise) +# if 'c' in self.noise_type: +# noise1.append(self.sample_color_bias) + +# return img, noise1 + diff --git a/ultraled/data/part_enhance.py b/ultraled/data/part_enhance.py new file mode 100644 index 0000000..77ec7cd --- /dev/null +++ b/ultraled/data/part_enhance.py @@ -0,0 +1,206 @@ +import math +import torch +import numpy as np +import random +from scipy.ndimage import binary_erosion, binary_dilation + +class EstimatorHighlightGenerator: + def __init__(self, max_gain=10.0): + self.max_gain = max_gain + + def gaussian_kernel(self, distance, radius, max_gain, boundary_gain): + return max_gain * torch.exp(-(distance ** 2) / (2 * (radius ** 2))) + + def inverse_square_exponential_kernel(self, distance, radius, max_gain, boundary_gain): + epsilon = 1.0 + beta = 1.0 + return (max_gain / ((distance / radius) ** 2 + epsilon)) * torch.exp(-beta * (distance / radius)) + + def random_kernel(self, distance, radius, max_gain, boundary_gain): + kernels = [self.gaussian_kernel, self.inverse_square_exponential_kernel] + kernel = random.choice(kernels) + kernel_mask = kernel(distance, radius, max_gain, boundary_gain) + kernel_mask[kernel_mask < boundary_gain] = boundary_gain + return kernel_mask + + def torch_line(self, mask, x1, y1, x2, y2): + dx = abs(x2 - x1) + dy = abs(y2 - y1) + sx = 1 if x1 < x2 else -1 + sy = 1 if y1 < y2 else -1 + err = dx - dy + + while True: + if 0 <= x1 < mask.shape[1] and 0 <= y1 < mask.shape[0]: + mask[y1, x1] = 1 + if x1 == x2 and y1 == y2: + break + e2 = err * 2 + if e2 > -dy: + err -= dy + x1 += sx + if e2 < dx: + err += dx + y1 += sy + + def generate_highlight(self, tensor): + tensor = tensor.unsqueeze(0) + H, W = tensor.shape[2], tensor.shape[3] + mask = torch.ones((H, W), dtype=torch.uint8) + addmap = torch.zeros((H, W)) + final_image = torch.clone(tensor) + + total_area = 0.0 + max_area = (H * W) / 5.0 + num_regions = torch.randint(1, 11, (1,)).item() + + for _ in range(num_regions): + region_type = torch.randint(0, 2, (1,)).item() + center_x = torch.randint(0, W, 
(1,)).item() + center_y = torch.randint(0, H, (1,)).item() + + if region_type == 0: + radius = torch.randint(10, 50, (1,)).item() + area = math.pi * (radius ** 2) + else: + num_sides = torch.randint(5, 9, (1,)).item() + radius = torch.randint(10, 50, (1,)).item() + area = num_sides * (radius ** 2) * 0.5 + + if total_area + area > max_area: + continue + + region_gain = torch.rand(1).item() * self.max_gain + region_gain_matrix = torch.ones((H, W)) + + if region_type == 0: + y_grid, x_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + distance = ((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2).sqrt() + region_mask = distance <= radius + boundary_gain = min(random.randint(1, 1), self.max_gain) + region_gain_matrix[region_mask] = self.random_kernel( + distance[region_mask], radius, region_gain, boundary_gain + ) + else: + num_sides = max(3, num_sides) + angles = torch.linspace(0, 2 * math.pi, num_sides + 1) + x_offsets = (torch.cos(angles) * radius).int() + y_offsets = (torch.sin(angles) * radius).int() + polygon_points = [(center_x + x_offsets[i], center_y + y_offsets[i]) for i in range(num_sides)] + region_mask = torch.zeros((H, W), dtype=torch.bool) + + for i in range(num_sides): + x1, y1 = polygon_points[i] + x2, y2 = polygon_points[(i + 1) % num_sides] + self.torch_line(region_mask, x1, y1, x2, y2) + + y_grid, x_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + distance_from_center = ((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2).sqrt() + boundary_gain = min(random.randint(1, 1), self.max_gain) + region_gain_matrix[region_mask] = self.random_kernel( + distance_from_center[region_mask], radius, region_gain, boundary_gain + ) + + mask[region_mask] = 0 + final_image += region_gain_matrix.unsqueeze(0).unsqueeze(0) * region_mask.float() + region_gain_matrix1 = torch.ones_like(region_gain_matrix) + region_gain_matrix1[region_mask] = region_gain_matrix[region_mask] + addmap = region_gain_matrix1.float() + total_area += area + + final_image = torch.clamp(final_image, 0, 1) + return mask, addmap, final_image + + +class DenoiserHighlightGenerator: + def __init__(self, max_gain=10.0): + self.max_gain = max_gain + + def gaussian_kernel(self, distance, radius, max_gain, boundary_gain): + alpha = random.uniform(0.1, 1.0) + return max_gain * torch.exp(-(distance ** 2) / (alpha * (radius ** 2))) + + def inverse_square_exponential_kernel(self, distance, radius, max_gain, boundary_gain): + epsilon = 1.0 + beta = random.uniform(1.0, math.sqrt(max_gain)) + return (max_gain / ((((distance / radius) ** 2) * beta) + epsilon)) + + def random_kernel(self, distance, radius, max_gain, boundary_gain): + kernels = [self.gaussian_kernel, self.inverse_square_exponential_kernel] + kernel = random.choice(kernels) + kernel_mask = kernel(distance, radius, max_gain, boundary_gain) + kernel_mask[kernel_mask < boundary_gain] = boundary_gain + return kernel_mask + + def torch_line_fill(self, mask, x1, y1, x2, y2): + H, W = mask.shape + x1, x2, y1, y2 = min(x1, x2), min(max(x1, x2), W-1), min(y1, y2), min(max(y1, y2), H-1) + a = random.randint(1, x2-x1+2) + b, c = random.uniform(-1, 1), random.uniform(-1, 1) + + for i in range(x1, x2): + ymin = torch.tensor(min(np.floor(y1 + b * y1 * np.sin(i/a*np.pi)), H-2)).int() + ymax = torch.tensor(min(np.floor(y2 + c * y2 * np.cos(i/a*np.pi)), H-2)).int() + mask[ymin:ymax, i] = 1 + + def generate_highlight(self, tensor): + tensor = tensor.unsqueeze(0) + H, W = tensor.shape[2], tensor.shape[3] + mask = torch.ones((H, W), 
dtype=torch.uint8) + addmap = torch.zeros((H, W)) + final_image = torch.clone(tensor) + + total_area = 0.0 + max_area = (H * W) / 2.0 + num_regions = torch.randint(1, 11, (1,)).item() + + region_gain_matrix = torch.ones((H, W)) + region_gain_matrix1 = torch.ones_like(region_gain_matrix) + + for _ in range(num_regions): + center_x = torch.randint(0, W, (1,)).item() + center_y = torch.randint(0, H, (1,)).item() + radius = torch.randint(10, 300, (1,)).item() + area = math.pi * (radius ** 2) + + if total_area + area > max_area: + continue + + region_gain = torch.rand(1).item() * self.max_gain + + num_sides = torch.randint(5, 9, (1,)).item() + num_sides = max(3, num_sides) + + x1, x2 = random.randint(0, W), random.randint(0, W) + y1, y2 = random.randint(0, H), random.randint(0, H) + center_x, center_y = random.randint(min(x1, x2), max(x1, x2)), random.randint(min(y1, y2), max(y1, y2)) + + region_mask = torch.zeros((H, W), dtype=torch.bool) + self.torch_line_fill(region_mask, x1, y1, x2, y2) + + region_mask_np = region_mask.detach().cpu().numpy() + for _ in range(random.randint(0, 5)): + if random.choice([True, False]): + region_mask_np = binary_erosion(region_mask_np) + else: + region_mask_np = binary_dilation(region_mask_np) + region_mask = torch.from_numpy(region_mask_np).to(region_mask.device) + + y_grid, x_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + distance_from_center = ((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2).sqrt() + + boundary_gain = min(random.randint(1, 1), self.max_gain) + region_gain_matrix[region_mask] = self.random_kernel( + distance_from_center[region_mask], radius, region_gain, boundary_gain + ) + + mask[region_mask] = 0 + final_image += region_gain_matrix.unsqueeze(0).unsqueeze(0) * region_mask.float() + + region_gain_matrix1[region_mask] = region_gain_matrix[region_mask] + addmap = region_gain_matrix1.float() + total_area += area + + final_image = torch.clamp(final_image, 0, 1) + return mask, addmap, final_image + diff --git a/ultraled/data/prefetch_dataloader.py b/ultraled/data/prefetch_dataloader.py new file mode 100644 index 0000000..5088425 --- /dev/null +++ b/ultraled/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. 
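+
+    Example:
+        A minimal usage sketch; ``my_dataset`` stands in for any hypothetical
+        ``torch.utils.data.Dataset``, and the remaining keyword arguments are
+        forwarded to ``torch.utils.data.DataLoader`` unchanged::
+
+            loader = PrefetchDataLoader(num_prefetch_queue=2, dataset=my_dataset,
+                                        batch_size=4, num_workers=2)
+            for batch in loader:
+                pass  # batches are produced by the background prefetch thread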
+ """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/ultraled/data/raw_utils.py b/ultraled/data/raw_utils.py new file mode 100644 index 0000000..745c3f5 --- /dev/null +++ b/ultraled/data/raw_utils.py @@ -0,0 +1,191 @@ +import numpy as np +import os +import exifread +import argparse +import glob +import time +from copy import deepcopy +import math + + +from torch import nn +import pyiqa +import re + +import cv2 +import rawpy +import torch +import torch.nn.functional as F +from tqdm import tqdm + + + +def read_img(raw_path): + """Read and process raw image.""" + raw = rawpy.imread(raw_path) + raw_vis = raw.raw_image_visible.copy() + raw_pattern = raw.raw_pattern + + # Process black and white levels + black_level = np.array(raw.black_level_per_channel, dtype=np.float32).reshape(1, 4, 1, 1) + white_level = np.array(raw.camera_white_level_per_channel, dtype=np.float32) + + if (white_level == None).any(): + white_level = np.array(raw.white_level, dtype=np.float32) + if white_level.size == 1: + white_level = white_level.repeat(4, 0) + + white_level = white_level.reshape(1, 4, 1, 1) + raw_packed = torch.from_numpy(np.float32(pack_raw_bayer(raw_vis, raw_pattern))[np.newaxis]).contiguous() + black_level = torch.from_numpy(black_level).contiguous() + white_level = torch.from_numpy(white_level).contiguous() + + return raw, raw_pattern, raw_packed, black_level, white_level + + +def postprocess(raw, raw_pattern, im, bl, wl, output_bps = 8): + """Post-process the image to RGB.""" + im = im * (wl - bl) + bl + im = im.numpy()[0] + im = depack_raw_bayer(im, raw_pattern) + + H, W = im.shape + raw.raw_image_visible[:H, :W] = im + rgb = raw.postprocess(use_camera_wb=True, half_size=False, + no_auto_bright=True, output_bps=output_bps) + rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + return rgb + + +def filter_bilateral(tenIn, intSize, tenSigmas, tenSigmac): + """Bilateral filter implementation.""" + tenSigmas = tenSigmas.view(-1, 1, 1, 1, 1) + tenSigmac = tenSigmac.view(-1, 1, 1, 1, 1) + + # Create coordinate grids + 
half_size = int(math.floor(0.5 * intSize)) + coords = torch.linspace(-half_size, half_size, intSize, + dtype=tenIn.dtype, device=tenIn.device) + tenHor = coords.view(1, -1).repeat(intSize, 1) + tenVer = coords.view(-1, 1).repeat(1, intSize) + + # Calculate distances + tenDists = (tenHor.square() + tenVer.square()).sqrt().view(1, 1, intSize * intSize, 1, 1) + tenDistc = tenIn.view(tenIn.shape[0], tenIn.shape[1], 1, tenIn.shape[2], tenIn.shape[3]) + + # Apply bilateral filtering + tenOut = F.pad(input=tenIn, pad=[half_size, half_size, half_size, half_size], mode='reflect') + tenOut = F.unfold(input=tenOut, kernel_size=intSize, stride=1, padding=0) + tenOut = tenOut.view(tenIn.shape[0], tenIn.shape[1], intSize * intSize, tenIn.shape[2], tenIn.shape[3]) + + tenWeight = ((-0.5 * tenDists.square() / (tenSigmas.square() + 1e-8)) + + (-0.5 * (tenOut - tenDistc).mean([1], True).square() / (tenSigmac.square() + 1e-8))).exp() + tenWeight = tenWeight / (tenWeight.sum([2], True) + 1e-8) + tenOut = (tenOut * tenWeight).sum([2], False) + + return tenOut + + + +Sony_A7S2_CCM = np.array([[ 1.9712269,-0.6789218, -0.29230508], + [-0.29104823, 1.748401 , -0.45735288], + [ 0.02051281,-0.5380369, 1.5175241 ]], + dtype='float32') + + +def pack_raw_bayer(raw: np.ndarray, raw_pattern: np.ndarray): + #pack Bayer image to 4 channels + R = np.where(raw_pattern==0) + G1 = np.where(raw_pattern==1) + B = np.where(raw_pattern==2) + G2 = np.where(raw_pattern==3) + + raw = raw.astype(np.uint16) + H, W = raw.shape + if H % 2 == 1: + raw = raw[:-1] + if W % 2 == 1: + raw = raw[:, :-1] + out = np.stack((raw[R[0][0]::2, R[1][0]::2], #RGBG + raw[G1[0][0]::2, G1[1][0]::2], + raw[B[0][0]::2, B[1][0]::2], + raw[G2[0][0]::2, G2[1][0]::2]), axis=0).astype(np.uint16) + + return out + + +def depack_raw_bayer(raw: np.ndarray, raw_pattern: np.ndarray): + _, H, W = raw.shape + # raw = raw.astype(np.uint16) + raw = raw.astype(np.float64) + + R = np.where(raw_pattern==0) + G1 = np.where(raw_pattern==1) + B = np.where(raw_pattern==2) + G2 = np.where(raw_pattern==3) + + raw_flatten = np.zeros((H * 2, W * 2)) + raw_flatten[R[0][0]::2, R[1][0]::2] = raw[0] + raw_flatten[G1[0][0]::2, G1[1][0]::2] = raw[1] + raw_flatten[B[0][0]::2, B[1][0]::2] = raw[2] + raw_flatten[G2[0][0]::2, G2[1][0]::2] = raw[3] + + # raw_flatten = raw_flatten.astype(np.uint16) + raw_flatten = raw_flatten.astype(np.float64) + + return raw_flatten + + +def metainfo(rawpath): + with open(rawpath, 'rb') as f: + tags = exifread.process_file(f) + _, suffix = os.path.splitext(os.path.basename(rawpath)) + + if suffix == '.dng': + expo = eval(str(tags['Image ExposureTime'])) + iso = eval(str(tags['Image ISOSpeedRatings'])) + else: + expo = eval(str(tags['EXIF ExposureTime'])) + iso = eval(str(tags['EXIF ISOSpeedRatings'])) + + # print('ISO: {}, ExposureTime: {}'.format(iso, expo)) + return iso, expo + + + +def illuminance_correct(x, y): + x_m = x.mean(dim=(-1, -2)) + y_m = y.mean(dim=(-1, -2)) + xy_m = (x * y).mean(dim=(-1, -2)) + xx_m = (x * x).mean(dim=(-1, -2)) + a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m) + b = y_m - a * x_m + return a.reshape(1, -1, 1, 1) * x + b.reshape(1, -1, 1, 1) + +def resize_image(img, target_shape, is_mask=False): + h, w = target_shape + interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_AREA + return cv2.resize(img, (w, h), interpolation=interpolation) + +def load_image(path, target_size=None): + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img is None: + raise FileNotFoundError(f"Cannot read image {path}") + + if len(img.shape) == 2: + 
img = np.expand_dims(img, axis=-1) + + if img.shape[2] == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + if target_size is not None: + img = cv2.resize(img, (target_size[1], target_size[0]), interpolation=cv2.INTER_AREA) + + return img.astype(np.float32) + +def image_to_tensor(img): + return torch.from_numpy(img).permute(2,0,1).unsqueeze(0) + + + diff --git a/ultraled/losses/__init__.py b/ultraled/losses/__init__.py new file mode 100644 index 0000000..45fcadb --- /dev/null +++ b/ultraled/losses/__init__.py @@ -0,0 +1,31 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.registry import LOSS_REGISTRY +from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty + +__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] + +# automatically scan and import loss modules for registry +# scan all the files under the 'losses' folder and collect files ending with '_loss.py' +loss_folder = osp.dirname(osp.abspath(__file__)) +loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] +# import all the loss modules +_model_modules = [importlib.import_module(f'ultraled.losses.{file_name}') for file_name in loss_filenames] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/ultraled/losses/__pycache__/__init__.cpython-38.pyc b/ultraled/losses/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516a2de0ec97cb3dac1fc89392f3afb77369659e GIT binary patch literal 1387 zcmZWo&5Ij16qht(dpv7w*#moy!TfwzqO=|mig z=mkly3wlPz!`Fb0d^M&n5>hs^x>pF21b_V`p{Q+!-MB zHi$N1Bqn+U>i;;R-*fyB6wh(Ap_Xi^AZz-Pyka~0h80~6uYNq)|9AqJyG7a5+N$E* z_4g_-Yq3zkM-L&S>&^=$IWqLL_Tm^k9uHnI(9#~`d*INBeeJ>B1I0oCF6jm&fP;hs zOtw${9BjYP4YT1T*#zrg+Y)Tep0=NMUP9l$3U-u|HG{&0tLOKy`U&Ds=DIGEy0JxF znMoWcus+whx#cxHfiJl!ikThXO77RyyvP?i!|Rjekxj~lvB|8iY*tin-{3pDY-DoN zSPk#KgrX^LnfSOaMDMD7iU_`6^`9h#NoHDRR*K0u(%5${XzE5|Wn2JLp)nOMEUJ0! 
zx}48cW(?;p;y$aKk3(px#-?3l#WB1KQFZP8B>gw<2M`VmV0s`(C!#SO(l)(L$2ar| zSX_vsxktikqXV$%J~{^fHo(Fb%6MQJi{x&+@0q|%LL2I(XOTvER_!Ma=j`p+xen)< zQb48ygN{AoI@!6wsYZIfb9j2-dChAjq)rF>1N0PLH|YS|O{zCdXEc&~hcHf(MzSut cz3G`CIA2c literal 0 HcmV?d00001 diff --git a/ultraled/losses/__pycache__/basic_loss.cpython-38.pyc b/ultraled/losses/__pycache__/basic_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55c1d1757499f3441535d280f8b1c53e8495e192 GIT binary patch literal 11680 zcmeHNUu+!5dEbAxw|69ulthWLtTi)jj ze|INRH5H+0f$m}FW@l!<`DVVq^L_i)&`?pr_0h-wy7KW?CFyVJq5I0=;R0@NUX~ccnXW zrHJ~VH6-c-s6THFTO)U6^bF$ZfOSyx455C=8Wr`Dt#1$SNLMB6uyy3FWF4{P)sZ_Y z-Z+rjZjE*rNiov9$d;>m=Gwrf(dznnWJ|V`d+LI8U!CK5wBMJ*0=KP>nEbvH%2&d?Gc9Jolh#IGsQV_b+x~rt zAHo|y{mSX4-KuYN%vSxhZ+m{V=34gY*_q3yJC4sy$F@!rg0^>h$@CgE0c)zg!H4nk zAZ||}s)*{mB13#YpjqiPY^H*ys zo|22|8g8*wM^N#lZLw8H!S;uwKpB&Eh-Av#146v4ggM{kwUtowZgRh@a+-!8L=mcX z+v`Il9I36Ce93LK8a5Z`(i90+Pouu~7CeLAUb^|pwh9`mgwoRzsHjC$NLvo5CS(t_!THw3uTUmZ0_})b~b>p!T8`p+!jX<22h4 z6}wGoBvaDzlG&(U!`*WcRp|02P}ovMwxo(;$raTa_()ooE#1oBEmSn9lY&*eD_3$K zfshp)v>FQa#T&C=LUS9!v*2^^cZb`P7uQX0vqjf89d^Uy4UR??>)tt0 zqbFkHY9w`ReA#hLzkE)NW^v=P%h}=$F>Z=owwFyvz&SQOb!IAskqQ?Z_k51`7ut5M zvD~md=C9Z+>-Jq{w%g7ILxp_{@i)solp{6$~2ltB4Dtdy^PU;0G5BY!9d7+1{&3Q9dtQRV{;rC~*q1W_i9D3IU< z)P8#Wt(KSWct)uv^xZkOd1R`K9@sxYjB~792y=ML=KRZm5gOHMqt)=M)v$QObUOA7 z=PnONswucCf%C$FE@tDTCc~VdQkWClT+Tt?*v>L96N=Q&U!>xfP>>FXr^&nhSKRKniOsALQW%6u+&c<$$Q8T-KhF zcJ0m4cpsjU6v<{2bzd@@eS_ehzh-m$XK(&f%1-$j2!APKD27*)y zMn)kfTe{#Ly*|9eCx^&?jWImgs{ukewUY~sSug}+)M&5blkALvP^Yy>H^=H3# z?DU_Whwr9a_ATgZPMYfhN!lX(2-Rvd!~B|Uw^6B2UON8wKl;kC^K%bA$Hm7msE(Vi z3z<$5Dylpz%sL!1v3Z~*2HLe+8WLNkP_x#Wdv+g%shBFwdUqvb+tX1PidHNYg&Ej~Z&MJyjb`P9 zZ&M^q+T+^{ghTTe-(L9U{Nn8NEcWozK?vmRG#e-5&p>b2>G(E-{c-sOd-LK~gbB@r2_j7&vkN*mFy-J0kS4K3_1{`;E zdNDQ+_Yy17KuX4m&^!D}FQiOq5C?rg;9{j=lnlXDPgH3XfMcXE_?v_Xqk@$G|A>&x z(`oD@$dDm|e+PF$@BQDW=AI7%;IE&HeR@6(R-slefA5l@uN!-SnUJ?%y2KO{^_mYv&2f{~P)=Lhz8GUjzmB&)I)Y ze9*_)BQfcoC~H%M^UL!k-YJF5afLeeiZsc({lh2oHCFk_0^- z?qE24DGQnskDLDi@1?6v@F^R|0gTC@GB#_r9kT}L?D_^bn+=Q2+5WoA*E|v2x?*B7 z6X(y1wOWUpel!93F9_4wd~)6vt~er$Cmuh#_`;=i%PRG$k&cI>OQI5jrJ+Ue}O_L+6;`T0}cs##8 z3J7LGZ$#exTL8GX&{M!12h9l@l)3oSx|=)mDcU_UKAE47NZipbz9wFOgi;`PO@!nq z!rt%e<5`8jkMJ8MEBJK`!Joi|<^-AoLrOVht$Z&t_+t!9bj4UdhKbHl`{?pWFPC=m zhXgfN97Cl7r?$Nn#~JkYg#G)D#zKwI@wm_x`X2~OZ^wDEC?K%;Zt|Rl>L2GQ+|8It zd_P-$jkqtls7_odIFc3-2c?htNT5bR=}(Zp{808~AHnE5$|{mN7=wr^B~wAZ+#E

-uNQN5&w*&Bibyn~)4M|InXeC0G^cvdeN;*$gj3h!GE?4= z-%t?vI?_k1L^{9d@srr&UvMXash4tsWol^axkzy7^3I~@5zLYhMY7|Rj}xi>N-j$C zS9Fn;qrDmqOV@2)v)ewxY{INR2QkT-BCEdMKrpxFHk)n>$uMCyJb%NnBj0pa(KR=a zTU{4Y56eS=WkjAAkwc5xgtfNK+6r@MB)JN=>c~^s!uIrr9$tsPz#2&9w3vO{Y_=U_ zdVc*BjMlA}tG|4Ta(i17Y<;BxU?h1_5-;vfqvB))W9ade<4ZQkLg_xI13+cA-f$eY zgfQK0q|8MiT&z1nxCan|ZRC<5xmN2qNLV7nMqav&lqJnW7gkgCHz>#+r5I`H;yGSv zD{@<+%{1Y?L@Q4wcX@IQvf{u%Bidu)Ic^1X~}BK+i9S42}$3T%zY#{5hLtobf{Md#I7j!^kmp(5|kj{udSOm5;^Gxya6c1D2t}wk(={?c$whTiQoly=Sjw^4mdJ}I=5mQ%e29{ zl4O)bmRpw<75>?7x4XxcBaSNICJXJo2Ka%rjjyPX%tm}2 zHDz1gq0}~bhNA6j8vm{(Z{-4YTiuZ^OYc{t9w<6D}A^gCb_2m>T;&jN>hiUaS$;opM1 zz(LwBM$mU42dE2dcH~|0$>01Vk_JkO@BUUFZt?il0&vR<-11mYepw;McJuG9L@7R& zeDM~ajY+LNoK3#nl?1*2<<97~qMZ4b)4%OG=bio~lZko%R}bL9bU+fugFEn;37-`r8sDi$ zh`0)_(|8YKJ(?xD;hpP6ggAySaYhh%{j!FSdEg}Rm$4OL0Ult>YrCEus!Ppgxey&0 zn$#z^Lb(>oE1|p|8Y#hxNv-OK8Zk&eWgWch5@ki;u`nWX=0^}~=%ZD}T{?dU@5C0C zzl*xCHQ>6dR41D8YgD{J1!-H66X!G&rVD49D?CCMJgyr5uKrh5j{kYSwSlc vjaQ?u2*^UfPM|R2PZS1~q6#CVMfV`p#j`GtzKZ*Jd_RqUx}VkMgJ}CNBSnBo literal 0 HcmV?d00001 diff --git a/ultraled/losses/__pycache__/gan_loss.cpython-38.pyc b/ultraled/losses/__pycache__/gan_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a0560c8b303e5b8c6183220c09e7e636f218a8d GIT binary patch literal 6760 zcmd5>&5zs073a4ot#-X@Z{l^Fv}_vGR%^?4P&AF>CUN~07xuQ;-P8!tf}q8bM4KYz zA?59=+)HgBC!<9#EzlNNT>`i$@(1x&^bq*gOE0;!zc)is(yr4SdMO3Y&O8nuXWqx} zy*K;u@^VAL_30n})xY|NqWpsf^OuRihq%+g#4uB|G}KB!NWTS{OBwWmr@WBP-}6a9(mE>%zLppGYXwj}7ypb;!R zRovxZEm#g#o~rIja6C8`tfF@;P`8!Vi5FOxve{A#+aF~~kNW}oS{$QobVqR~3zOK7 zFj~Iz)9r0<>&C6y+dEsIbm+DD%P0rwb#$0Qn{%~5b@kxbr%G21jKF+acMZ98*9^o$ zU8`JREm$oVXn(30id)Ch){5G#tDAR{H1+?iD#~@%4P)kj$S+glEo~$9^0}k$>i#b0$hLi#ogI-+mbP*qL_4gfl<#BL6;% zyvKgzob5)5pMlvq%7+46?}@XQnN;QUAS}O`USWK_ahcups%s?0_~Q%?cp9@2B(v_QwNQruBev$ zX$`fh%DaKT^lNJ4)e;ceQ_*!u2GfiB8t!Yj(^t{t8ehrP17)HtL7lP(wW{OKNFwtv zLdLxaB}571;Rrsr!{Rg{IJqsxnHbLeoMvV|?=eihpCl1Mg&azFkU&MVwpcde@f=rHeZuJuegP;+%jnW*HO+wfExPe(1)-vty`OLcZ(c@9~VLypiEUFr!8aneDO z;JH5>^5j?Hz^7Q^WC;L?Rj@*~=S%}l3F@@>+p)8q?L`cmc>7$5N*zCf8O`9H=8@f} zS=bTUJPJ_nRUe|sn@M~=^`QhTCGO2$sX%aU0sn{Qpcm@R!k#)W|0Zm@M*f~Vn_wG83){9Bk@tVnTA_wF{zF%ptP+AmgtL|*w=NdF@+NNqGqv!NR5syg17 z8pPqLH#?;M=|y3xGKVe(xpF|!TAOJ5WCE&)&L8DEzXH>Ei*O^>a&;f^{6GuPJ4raA zPn3Q0(GD340r9>WBNiXvX>NcIVm<#wZo;@sa5RJw>3#qss)YLWCbI+0OR%O4h3tUh zi!(@qQN(~AZPq>~$T7=y6aKJbO}e6M$XO7ITz94{qAyCl!uKNUP9Jc~clCQndL zzzaqCmNhd(eg=f1R_@4c&f6ImuPtIxQN!MLeI9&*_fpCrlqLms1EGFhYidn>d~9lr z<;<4OYnaB#Vnwf|-fT4}U5fnyL%P)PJgHgqo1XV*u4kI)^z(_zC5*|`%T<~d1s&Tdoj z&P;Mg5^&me<{3v{S&1Onagr;Fz@F ziB4f>-$K=3y`$_KkB;X?X6EV>^)zKIa;DhYMBhefjg1-Ot;+9bfI-DhB~FR=DWILt zQO_Qk(5e`5@|_p|#>Zl_An`hPdIk4g9scBPlff}K9v2HT! 
zm--2rM@ly-6MkZ>=)z<=BL4)9F7gg4|E}Rq-$#=nY!LjT^g{Ruv}bytL%CY69q569 zaDfu=fl0K$oS;AqEYK{dTF*&|HGc){w=Mx~d!xwb;aFysM-a3v>h(^-5g7vJ6F_H4 zrM4l%p$1~C2nuxQ5u!vxQ7l0`@G~DJFhViX@;*b(I+w6kxgu4<84vG|L|Tg!HyCvo zzCu|D$0>o}GC!yKI@Q@UZK0H=6*9sy%8)P8TErAu$^fSq-6RS`46C?Q%NCo-_8inD zBQ92edjBZ@jvrG++CiMkX8E^|FcY62GoCmVLmZj(lOCP>l-au6%XAqc>rX~mwUTty z?J`aoH9Um^ZzvZ}QKtvYUP!;Xv;6(Wqqg%dO}^I;JAFvaM?nlO1LSi_kcw0bJW?k} zfPEQcgF}l>5;xAT645pB^?f>`N~tyX0?E(?SC`Btl`( zwhKKS4GMjj>=wr3knN&S@B66KDF(I-w@$I$Bj0pw(G3SZw@xpLOw?_PZmEYkXCa_0 z8vK%1KKd=j!d;Z+hWe(;Kg9TpU$1|D>+j>IA1tI^lqQm23BUOxI-(FmI+ejsC{YU3 zQ1dZ#Up9TnK<^vpKLNV`A?+P_PqJ|`2SOLrYPQO)mZE^NNVy{Io-#AO% zRVt0WA$WN})gD-y0qgjCg&8ME$~iG~n{&@ow=PEXUisiFRJPWWd@GP)2Vf==0!4jG zeO-NBTi1?j{3oC;sEzUl%3FlmK1YWR9Z1*|*AArepfdVR0CvcJ5wq42pK@~p!WW{q zGUI`nTXbUjOoMWi+=QzD`}K0>U-LS6VKhMfvh&=c*%L6Wci~Cb__wABePU2gf^BPNXMP6k*Q&EiHR6U9&q##;O|mFFok}g8!sB7&d}TUgIzgOR z{0&=6DP>ZY8|th2y1Js-$bmHcY|(088WwKbR&DJy&{j0=Vs=aASH+*r)+)b&o)D4* z9|Et!7EslWV@_EM|2EN1P7g2J2=h_e1|m?@-{X|j6bJ@1WW2DXnTrxZCe9xc4+04B o@wu&4DRg09k#)@bltiW*Xs~3Qd*X+}$Wqt)mU{fOz1C>{2Z(aoBme*a literal 0 HcmV?d00001 diff --git a/ultraled/losses/__pycache__/loss_util.cpython-38.pyc b/ultraled/losses/__pycache__/loss_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb4bddc249a104c434092295a03295113826de0 GIT binary patch literal 4467 zcmbtXTW{RP73PrKU1}rAiDI;>6F7qwt#(s+CEE>92(}v6b_)lp8^~&bZqTkNa&|?F zOKxXKDKd6n>Yz;@+}EOgl7Kuk|06F&{{dh7q#w|i0=;}^hI{iR?u8O(XNG6aT+VmC zbLJ;YOLYs+Z-4){gU?nh>#x+9eH=92$E!X-5n3UOt+x4PZHBfT+kL0)GAndK_Y zGrhO*s_&r4tVitB9sTf+o(5&)~RzwZR^yf-_XH4?&H7giZBlZ zk5i?%uechBAnHaUY}M-=&n?+gTe?OqE!Ef)iArU2i{Fh!UnJRaq)G*u{?eMrOHxj& z=M9z7h#wEAg_)$_Ye|}jwROJM7k;wF{UqdTD(|nAGw+Hlm&wfJX4)-oRvene%+4qo z&f%Pj`^`$heo)jqT9&3Pyq?H9BFX!OOY;>DrfNEcOIs=yCJQZ=Csi%|xg{5&@bB+! zV5hy4+)sKNnNV3LNJFu)bN9|h9%s^zMYutlgwj8qJd5Ji;G}p>YZD#%R)ro#%B=dL z?J<{m%(LY(+OmcsAajl%s=Ii#dH;mUkin!2@`E8ewa&L4kl0;=1%%)hPZ!6|oX9lL1(63feSW6ze$iBrq=XFMaU1oKA1vU95z6IoF zM~7!3r`r(LgZHUzkuLvDf|qF&CLo#$d80nL=`w} zpZB68VUC}V8L&LJy_DzRo9rM`JQP7H{VbL1XqOx`7U%t3W&FTDB3Z#z)JuFY3Z$Y# zp)L22u+$;>#sCKZ9liuhtn+u@Jv>HLH6OP49>z{orU_7;2oWltr94RcgFG|3=Ace` z`EisTq&WfpginA*`exS+sW|3-7?#JN*F{V?83C&7<;1djO4`YIHD+z)5#apB{YEEk zTx)xN^Zcb7H@3J*Ua~%2!cBwQb**mZY)pjMy8E%;AH-tIG`F|6dDI`IGQ$zUVe}yC ze1B|cO@@>1P;_~GRrk$Vp=qKHjig8FXu!1yzfxMN+4A>QqdA?J%*bz=iCZRKG(Or~ z=hxQx^`AE?aT@rs+HRHNP223DX%`FHf6_Eo zMTNeqCNd~m!`G;wxGld*h33trIkj8pHn*@J>g2MSK^21Bk z$fnGVLa_PkxkQ&woslziD1ejX8PzJLHgYrPkvnuLkHnw%)IF_)3unxf<*C;F9{Bl( zB0~VfeoI&>-c2R%i9~=iqok*K(MJI3`a#C~{(z`==gv+^wD&R&s?wAP9msI_{)ry) zkWwirx=|v+&UkzW;~!2q=!Icj3KfMpmNO04yW_5XksS*#K0zP~GE7D+*pq3FpG@Z2 z0a(TFjYA?)!Xbq_-nf@!NGx?7%O|!g_Z%f*di<(jTY5MhjPQd4K8QXRF_HVe;9;Z! z8TBK?4O$Pf)&y)^9+wu>4z8o~ImN)DhEs_}kO5B4zz>mXB1=3+w6mJ_`M>a-{}aU? 
zBJ%|j_wr;v@)bZ|I}~Dom@dCgvQ&mECp16^(cKIT3TeQ z_A0w<%PZ&$W(j~&4$1{I)FmneZ{!0uL41{R0_3To1Ax0TU*HTgUqBO&oMDBs&d`N) z<*5@^-?T>645UH5AP{?08&-!k_t3gw4Ly*G=YJRGxQ#1Q9#fJ*pqLnL zigRtf#vPIQQ4F9*%bG#e)O2uO^}SIIWTDweg8jGLxRtY|G&I2Mb3vqE&;Yx@|32j|J56b z85*#hF+Dho{0l6`_fv3)p$1JYak>vGO81dMAAt^WD{Y!n(f+)Lu=Gu!7p_V!DZhc@ z{2y>$Ld zQe^m7(K88+bxBtOy|=>CEYwO82?w(bP3apf-3fHu(;L+aW>%r2rw*Ln4d5q_#Wvj^ T^whWovqi_dSm$2T`^vuonS12D literal 0 HcmV?d00001 diff --git a/ultraled/losses/archive/conditional_loss.py b/ultraled/losses/archive/conditional_loss.py new file mode 100644 index 0000000..dd5cb27 --- /dev/null +++ b/ultraled/losses/archive/conditional_loss.py @@ -0,0 +1,55 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from ultraled.utils.registry import LOSS_REGISTRY + +@LOSS_REGISTRY.register() +class EMAConditionalLoss(nn.Module): + def __init__(self, branch_num, ema_decay=0.999, ema_time=1, core='l1', *, conv_count=1, reduce='mean', loss_weight=1.0, init='eye') -> None: + super().__init__() + if init == 'eye': + gt = torch.eye(branch_num) + elif init == 'rand': + gt = torch.rand((branch_num, branch_num)) + gt = torch.softmax(gt, 1) + elif init == 'randn': + gt = torch.randn((branch_num, branch_num)) + gt = torch.softmax(gt, 1) + # gt = self.label_smoothing(torch.eye(branch_num)) + gt = gt.repeat(1, conv_count) + self.gts = nn.parameter.Parameter(gt, requires_grad=False) + self.ema_decay = ema_decay + self.ema_time = ema_time + self.counter = nn.parameter.Parameter(torch.zeros(branch_num, dtype=torch.int), requires_grad=False) + + assert reduce in ['mean', 'sum'] + reduce_fn = eval(f'torch.{reduce}') + if core == 'l1': + self.loss = lambda x, y: reduce_fn(torch.abs(x - y)) + elif core == 'l2' or core == 'mse': + self.loss = lambda x, y: reduce_fn((x - y) * (x - y)) + elif core == 'cross_entropy': + self.loss = lambda x, y: F.cross_entropy(x, y, reduce=reduce) + else: + raise NotImplementedError() + + self.loss_weight = loss_weight + + def label_smoothing(self, x): + x[x == 1] = 0.8 + x[x == 0] = 0.05 + return x + + def ema(self, x, index): + self.counter[index] += 1 + if self.counter[index] % self.ema_time == 0: + self.counter[index] = 0 + self.gts[index].data.mul_(self.ema_decay).add_(x, alpha=1-self.ema_decay) + print(self.gts) + + def forward(self, x, index): + loss = self.loss(x, self.gts[index]) + if self.ema_decay > 0: + self.ema(x, index) + return loss * self.loss_weight diff --git a/ultraled/losses/basic_loss.py b/ultraled/losses/basic_loss.py new file mode 100644 index 0000000..c0a0970 --- /dev/null +++ b/ultraled/losses/basic_loss.py @@ -0,0 +1,380 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +# from basicsr.archs.vgg_arch import VGGFeatureExtractor +from ultraled.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + +@weighted_loss +def raw_loss_sqrt(pred, target): + return F.l1_loss(pred, target, reduction='none') / torch.sqrt(target + 1.0e-8) + +@weighted_loss +def raw_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') / (target + 1.0e-8) + + +@LOSS_REGISTRY.register() +class TVLoss(nn.Module): + """Structure-Aware Total 
Variation Loss. + + Args: + loss_weight (float): Loss weight for TV loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + alpha (float): Exponent for gradient weighting. Default: 1.2. + lamda (float): Scaling factor for gradient weighting. Default: 1.5. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', alpha=1.2, lamda=1.5): + super(TVLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.alpha = alpha + self.lamda = lamda + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, 3, H, W). Ground truth tensor (RGB image). + weight (Tensor, optional): of shape (N, 1, H, W). Element-wise weights. Default: None. + """ + + weights = torch.tensor([0.2989, 0.2935, 0.1140, 0.2935], device=pred.device).view(1, 4, 1, 1) + I = torch.sum(pred * weights, dim=1, keepdim=True) # (N, 1, H, W) + L = torch.log(I + 0.0001) + dx = L[:, :, :-1, :-1] - L[:, :, :-1, 1:] # (N, 1, H-1, W-1) + dy = L[:, :, :-1, :-1] - L[:, :, 1:, :-1] # (N, 1, H-1, W-1) + + + dx_weight = self.lamda / (torch.abs(dx).pow(self.alpha) + 0.0001) + dy_weight = self.lamda / (torch.abs(dy).pow(self.alpha) + 0.0001) + x_diff = target[:, :, :-1, :-1] - target[:, :, :-1, 1:] # (N, C, H-1, W-1) + y_diff = target[:, :, :-1, :-1] - target[:, :, 1:, :-1] # (N, C, H-1, W-1) + print(x_diff.shape, dx_weight.shape, dx.shape) + x_loss = dx_weight * (x_diff ** 2) + y_loss = dy_weight * (y_diff ** 2) + loss = (x_loss + y_loss) / 2.0 + + # Apply reduction + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() + elif self.reduction == 'none': + pass # Keep the original shape + + # Apply weight if provided + if weight is not None: + loss = loss * weight + + return self.loss_weight * loss + + +@LOSS_REGISTRY.register() +class RAWSQRTL1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(RAWSQRTL1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + + return self.loss_weight * raw_loss_sqrt(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class RAWL1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 
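+
+    Note:
+        As implemented in ``raw_loss`` above, this is a relative L1 loss,
+        ``|pred - target| / (target + 1e-8)``, so deviations on dark (small-valued)
+        RAW pixels are penalized more heavily than the same deviations on bright ones.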
+ """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(RAWL1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + + return self.loss_weight * raw_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. 
Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + if reduction not in ['mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) + + def forward(self, pred, weight=None): + if weight is None: + y_weight = None + x_weight = None + else: + y_weight = weight[:, :, :-1, :] + x_weight = weight[:, :, :, :-1] + + y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) + x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
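+            Concretely a tuple ``(percep_loss, style_loss)``; either element is
+            ``None`` when the corresponding ``perceptual_weight`` / ``style_weight`` is 0.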
+ """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram diff --git a/ultraled/losses/gan_loss.py b/ultraled/losses/gan_loss.py new file mode 100644 index 0000000..d1afc90 --- /dev/null +++ b/ultraled/losses/gan_loss.py @@ -0,0 +1,208 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from ultraled.utils.registry import LOSS_REGISTRY + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. 
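+        For 'wgan' and 'wgan_softplus' the boolean ``target_is_real`` is returned
+        unchanged; for the other GAN types a tensor of the input's shape filled with
+        ``real_label_val`` or ``fake_label_val`` is returned.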
+ + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. 
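+    Following the WGAN-GP formulation, the penalty is
+    ``E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2]``, where ``x_hat`` is a random
+    interpolation between ``real_data`` and ``fake_data``.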
+ + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/ultraled/losses/loss_util.py b/ultraled/losses/loss_util.py new file mode 100644 index 0000000..fd293ff --- /dev/null +++ b/ultraled/losses/loss_util.py @@ -0,0 +1,145 @@ +import functools +import torch +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. 
+ + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper + + +def get_local_weights(residual, ksize): + """Get local weights for generating the artifact map of LDL. + + It is only called by the `get_refined_artifact_map` function. + + Args: + residual (Tensor): Residual between predicted and ground truth images. + ksize (Int): size of the local window. + + Returns: + Tensor: weight for each pixel to be discriminated as an artifact pixel + """ + + pad = (ksize - 1) // 2 + residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') + + unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) + pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) + + return pixel_level_weight + + +def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): + """Calculate the artifact map of LDL + (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) + + Args: + img_gt (Tensor): ground truth images. + img_output (Tensor): output images given by the optimizing model. + img_ema (Tensor): output images given by the ema model. + ksize (Int): size of the local window. + + Returns: + overall_weight: weight for each pixel to be discriminated as an artifact pixel + (calculated based on both local and global observations). + """ + + residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) + residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) + + patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) + pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) + overall_weight = patch_level_weight * pixel_level_weight + + overall_weight[residual_sr < residual_ema] = 0 + + return overall_weight diff --git a/ultraled/metrics/__init__.py b/ultraled/metrics/__init__.py new file mode 100644 index 0000000..8c50e94 --- /dev/null +++ b/ultraled/metrics/__init__.py @@ -0,0 +1,20 @@ +from copy import deepcopy + +from ultraled.utils.registry import METRIC_REGISTRY +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. 
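+
+    Example:
+        A rough sketch using the registered PSNR metric; ``img`` and ``img2`` are
+        assumed to be HWC images in the [0, 255] range::
+
+            psnr = calculate_metric(
+                data=dict(img=img, img2=img2),
+                opt=dict(type='calculate_psnr', crop_border=4, test_y_channel=False))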
+ """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/ultraled/metrics/__pycache__/__init__.cpython-38.pyc b/ultraled/metrics/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36b49aa28d82bef955bc33bddbf4baa94b734a95 GIT binary patch literal 752 zcmY*XJ#X7E5G6&~l4U3HP@pTvsDnIp35ueHYXmNu4o*5$3n7TKDuDVZNewRSsa>+P zW7cHsU*g&+KcGve9%ZC)58&Y)?}P8~_~Ga%LSWZ7zdt1&Lcd&aXas^+aP(Vn9C0jB zg;R_j_M}&llmI5&mwq)!2R$ChVHKo76{aD?Lmu$3AnAyYc=QdW5g%WocybT(P>gNJ zg=q7p-2gthSX`}^bGBNXFE6iFAMf$L<~Yw}-pR}eR+pbe?Ae2xR@dq<*1D|J5E@nB z+yGn-oBcB%kUMiY`WZNbw)iXF;Vr(!9N*!cXS`do_3rQ*Z6SB#eaAbpCBWMI@?-u$ zg;v6-GN)_RRFr2Xqgl;q)0(oWbrMAs?6oTNS&v)a?MGQ=6=;AA4M%*|dV5YNMd|5UvvjkpF z$9ula)!X4d_=(gpRR4KCt3+LFx~wi{MrgxwcxLu)@n+UZqcSNt6u52mjIpvV4P!~W zvC~Ig_G$N*bO;UwA&&49tfqk9$7Aa|iT5O*KFmxyQ%apEQIxRUrYD~IPT^j)bmXLE p|LbW3#xf}x>nHc@58Mg7X2Ta%!#gQny5FN6rHMDe6Fl`I@)zqN!T10G literal 0 HcmV?d00001 diff --git a/ultraled/metrics/__pycache__/metric_util.cpython-38.pyc b/ultraled/metrics/__pycache__/metric_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5ebb80190d8bfe04b7e6f3d228f9b8e36deb793 GIT binary patch literal 1480 zcma)6&1(}u6rb76S7TeN2!37W)~5~3 zCZ0WrV2}L|{8!wQ;HiIr7f-&on>JBYoG`O9llS#|@BQY<)Ks0}c>Df)?|X@{Z*&+R zHXfF7>KnL7CV9YiIDW0b>f1XuXR;{vs`IRxys+k&U;E{JUQ`V;d?0sX+LBx2v=;k(k90W<`G6tk(sO^^Pz- zCHyc>OnWG<#ayo;4xEMnWfBOBV?*3>?)^*0YluIpdH+@;ords6-drtwBYZuK@Bwu7 zV(ypX_4{2h7fKhvO`S!tPH`6innqo7F;|L5Xj?$rp3u=hi4P^oOg!=Hrj!mOOb&n67Gp^hW194oW2CfT_$FUJg~QUAx?`MwIf zY2t?6d84#xdy!Q04_8;_F#+HPN}_-P-_yfyJ28IHjMHpt2)muPZUv2LvdAvf_za)s zHT+F;n86#5B0U|nCk#2ER&ayVV21BmM@X#eL>abcOAAS}gu9ZouOiXtWXeb?B=S3p zWMn0$7t;L_k*Rw4xz$4qiWsylJluyW5K%{Lw{~0bUu{}2@W3}cftf>7&u%uv!ks(M z|JDHGJ^oi^G1mzq*9eow$`NEeLN-8@gtF-4Zl0O1(eerP3(#Zp7$2od+K$2?1xl<_ z$ttcU3z-$hSv9wKdm$@@((eO3k3jw5aG0b$m7!S9<1U*tQ9Es)rUY+v{-93G@GE>0 zs|1(v4t2>rdMbxWKThEisZO6ZL6fSkb1!*GbDc3qxy#&JZK)s0BvALM(3+yJS(A4C F+E2*kf1UsU literal 0 HcmV?d00001 diff --git a/ultraled/metrics/__pycache__/niqe.cpython-38.pyc b/ultraled/metrics/__pycache__/niqe.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f3242e04660aa86afd22425383d6713225bee13 GIT binary patch literal 6824 zcmeHLO>Er86(+gd{~yV+Y|F8m42(3^29A{2u$>649Yu0vBe9(*ZVj^zXliG)OPNcq zhg`{Pa4%7EDuU(~G$#o?1-%tL_0(&xMNd2h=%MwohhB;T>Gy_QX{9(p|Ikx+!QuQq zzIpS0zIi{Y*DD%+KmGi#onJ6b`y(~Re#cs(esk-TvyA`X_ty)#p&U>|P-Kwj4!E1CUtO?YMy!5eVP4Y6Y;C+No z@fxpxtXorJV&~|t#*gwze&l0)TeqgU{(;tJV>8ZUy;_TTrC6<@wJEif3T71JVqD<) zqY&%~ZCXQp^u}DO(@e}OVO=@S#if~ygJC(=R6uvsG z4Xbfp@%rsEYH=m5#lSq}yUE!WYuvA`U!*KyeFx)(T6^Bh|fkz;$KHgkNQcFDc2^+2+ylkQ}X z2jYW|G0kSP9lL&8%})0aGS_F<9`zmC7`sd@BGq$xLN@0Vwe}Z3r{C|cwm<#cpW4=5 ze6#)cFaPrUn``e~eQH1(pXN|V($%L_G%HDQ-E*Qy_=(~7k^*GqL-Q+AwA`KyT8@_# zLVQsQNh=dGd?cg9>9#v*wWO@C<3yGzsVq_H2NG1&k=VT;bZN&)UP(JCq*Bh_@=Hz5fVIE9Hs2L~d(g)Q z<|7eCb}Qgw{@rWK^L;OZjtGtkq%`5Y?>-W9y+Lvm)0AS_PP@%*C1tYR@H7goQqmiG zMSsaSt)JFQ`WbyzpEBruMwicF{D8@q(XnUpze1y}!N6gt+xq?7&=}^1W~}e%u>rF& zVMzKZt!?b;@L_NUZ9LpiWm-4eM%x_bBO}Vi`EB^GW$nZJFko{d5980r1(94k86PF$q-_!pkKT*PI~JdOv=nrEeHhxh zaG>zwziKibG<<=*ol$06Zq#8OjtsA0oxS#Ys==}Z;VBK36ltD*jbGE)@Ay(rOy{v1 zHjC;Il7e^)?d3^6)!ZB?${H3XrVKn!X>GGCr&SN_N**Kn$Ei3$orSRDu18^_yH;IQ z1-F$Z%Of<0R4Xk`((H*=Dj+)(P!pXY5!cElwB>l=v_7d%Lr)k=`F-?1eSo4>Kr#9e zN!pr<4vC5m8&HUcMmQBAHf-x3>9Mw9+&6{=9B*#Nj0>?z-V=V8T1E2aXq6EK;Dckt zj8C9IaXCh;_{3N@;D<^5D%^zH{CQYKG|}Q*rZBaruv5fSit;EyV^<%IsdjBwpFtn! 
z*RiTLqs4Xd{P6k>(42@WaXC}RNz_ZKertFnZp253leqev79XJ%qt+)zJV~6Xkx4a@ zwK86(;)$JVREwu5@tk1wDnGdOSg#=vd2(4-I)gYGBW$QyK+)UWNKoBrHLx zwMX}(!p1~g0VJq>Y78rAO`?@kt)ns8J4fQmuD&s)X6O@$p2KO(IvP*&I&WajF|?+4 zj=@Bar3MNa9LGBKAif=iEROFQ=sA(Lz?(5Vfp#%Ifw?E58PrcweH*R#1Wane`DGp_ z?Rc#|+^k@?ZoGeu&E6v8a=a|gyx&Iviw5l4rsKi42eQdhN3|Dw))&}K=OG0Hhn*o4 z^$^dz0ej1HeSYQ;nML3622X@c&FN0ZMcAMf@QzU$Ia)#A;|%Vd_gexXtQ&-piXOph4h>r z>j(-MDxz+=JnEcdAtusdF<=N!iap2Mat0xLbrcy#ydM%kCfBzSirYZ|NeGFs+;u?j z9CN_|HqSTR&{^}q#^8V$4v}I}puWuAO&2^;R6iSsc(-hDk=0RtZ|LpS*gGoxgbT!_@yP;KZUqPo(NfbGv=JC*3d&$KzYM3DiNM z!h!CJSGZVreaP;a@==l;eOSqg0n4OD$QjU+0#5i!1R9Jx5!5o;XQ35rehUvvfjSCO zz3=4Y00Re4?O)=goq<|=3#J}yjf-QEU4DFdUn8;!d-@c1gCOc;2=V|s>a_MZtD>Dv zdTDpqtm{Vr!PirB!xl+UYKB6e(3Y)^+v<#TJA}HE=JS1WB=u#n?m%r9*c+GFEK%6A zM1aBo;N8a(M*dEWDajZk*l5bgT{3nO_5X+u1wWoFja$vXsrw4D0J%`8S5%Xt51$Cg ziR+_DI;n>Kq+r2U!JBPW(Z0HR2RG4t2z3=!&G94nUe|%MYfD0g}E zTsTou#9cxNT%c!?9Q;&LKVa)gL~nZ)xhl@5o8@k^sv;1 zJghpgzt?NkAPco2x_~Cd%5!p`SKPTsj*1%f=*XHNA#w zNJVRWT{5Q7Q_`pOiqXh5j2XR=o7Mq7wc6Kaj@HT*&`EOF?=B_!phXFp(XWaM9qT)I z!atyhv<+k=RGEu#6#$CKM|or!s2@0!1L2^rfJ^|PO3Du8@Gjwv3iL8-(FxB=PsXri_uyDzc+l|%S4 zMUejl#Y-3V2`Y+r1>qs(>-QptGTD%guId0wGU$Fs{2rq*9&D870B8jIntkKaIYg$W zf*=9VQ3f)fecS!d@YI38dk2xw9m8Xz$bv+Uy1hON9L*yR9s)=dVfRScDq~3zG6i8R zU?c#wv)AvQC76cyQd(I;&m5b5Yh@Mv2tMf<5-s}ONF!{@T}l!#_~2>1)-2^co55GOAv&!n}*5&1E{BXFv}^9fFk1QJ|jvn znXG4l&w-Kdb_764n$G*C+|HgmB)xrFB=0?hq(rL4pn&Q(K7!K$B-V=Eh!9%`D#UhHtxPBM7 zq&g-fd4u{*;zm-EBt=70i4lZA2?R`v+?75qoJkRRMAcKW?RD1^ww)z#kW}S~xe>TN z07^)g1j<$-(|X~&q_8I8<`L*xo9B_v-;9<558$V9j@Nf&t7+R&(+qdu7%Qsf7$-{CuBm;mA5>zU?!$H0= z(kWrnte@d?;fpA==ckQngn9gp8cKR9c#b2um-Lxj$!H+!fqB4@FX5#M)UQ*WB-}hE ze@yjx6iFV~r#DDZ-c58^$3kRr2Xh1xkmy-u1ZvrdmPrE^)P zG<(}frF;7#w_ezrM+4y;@Sf+a+3S5@-CDyGJDxZ+ONurSlN`t*aZDLhg*Q{UQc6dK sRNEbspGv@~+*|1f0L^_*T%pW$Nb8m0l-O#2&E-<(-254=IJCvj6}9 literal 0 HcmV?d00001 diff --git a/ultraled/metrics/__pycache__/psnr_ssim.cpython-38.pyc b/ultraled/metrics/__pycache__/psnr_ssim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..123f8fc962f5597fdf960df2791794d809712c37 GIT binary patch literal 7307 zcmeHM&2QYs73Xl5+)u4!S&6LJO<2XXw5_~rE1wWj$F^)aZek;l)EL<|*ibvOE6QAQ zGvrE9K=n{3r?x4OOAkc==~5s-k-s5_-g@bo=pk(nhE6RCq$u1&`+LLHYVF950;5R| zp};pE@9i6MIPW*VdHV6tP+r6Ji{E^v@D1cv1C?_Ie)WL%gJVj`CqX@~&1J<45=?%11a`(#qqzcxLT- zS&xUL@Ff?r?zHWeh>M|L->f%QZO;>KnZ-rfTAAH!tTbf36QVtQ?eg;C)%p73}gn(mcI+rk#>v0XF9^{!k@Nngyvdi2YnIi*I8(8XE6plrt$ny4O%lYF|sZ{ z4RgG33$w|Ku*vhG$UF}Hn1=Z+*;Bfv&D(CH>)N5PUR}DrXidK=?3>n-)ADThbm*VH z?mL077VXgS%awfI!nG)xbJl7Yc7nNTRd|&R=cdyU+_5XZY*ndMB`5Wu@5Cb~bxBSV z7Z+qJm{ZQkq0??z(;l~_v^UFh)>Zn*fwkd;tCqCAmau+uX2zO*=9xDtgFl)5vcG9a zzf)gP3(1;xybwdpqYIC2b)2__8(5BK3A?ds364Sire%A0HatwlT3!{HhMt)%h2{An z%-4N)9Y4qK!c}24dQ<86-oTT3oo-m~jT(d!quy8*Fqi5#4ymm~?F7~oR_YX**3|q< zH>N7qCDF9IZa9}H_PvH82=`;anqKjJ7oL~VZTX({y48ygdW+ZRSLWq_s|DK)ME_fc zUFi*ci>B+_7*4GPDFTe040`8;`YcwkYIpE)XbQN~YziqnnBUQ7te90k(j^5|oWjH3 z8NDHWuZ6kUfB!o(6>F*6>G(2Kv)0!n4kz*QlbCdFO~S$CM=UN3{8}=&uIysTVgSc$ z-bMx2$@48LtWW-azFde6^?Y&ZLJ$av<>pF%!)n>ROXEbuth4#Y>A zoP^DtXRB@DwKlt0&sAmK@VThIa`{pfQ!H&)a2SNrX#~|y;K_OrIPFSjGcN5dDp9BY z0(=GJP&;ypW!VTTvZ9`4kFm$}G}>96+M`rs@+7Ppv;pj0^lkc}_h&Q?33@s}FD2I_ zjT_u}570+)53csgW;%hhFSG$T$`3*v-pn3g0%Uu20=QI|UNd?Cw^wKAGkXA+J*r`I z$?|+F#R~vg1JIV<&H&mBUK-c>@HTAc@TtE9yaD?*mwk!&Ul%gK+WG$i-m>rlpWxtN zygU2GOcHE>xf$!J>e(4kBNCzzap<}PwWFG;r|H)*`mA>;fTN^V-(mKlQHtbyE2fDT>K-Z;6H ziOlU>mNRT`J94*y zJi8~*|xI9!!NoZe#!is=M1mYNq( zk^Va5RujgG813289*w6?p1=lNf+X>t=Wn^);kaEN#|g{7E}o}8_8`x@7bfX=DT^eA zNR&trNLCbHmWKv#@DbAW*_9;3lH()@5Mx78d2lWeZjZ-?VmPs(NWVf$VCp?UN#trD 
zEo%m?-ymjAf|kUB6386jNMy1MMZ%FL?$43ZY?MthISbqSbEHq8JtTI16F3qW63>ni z5<=brGl~r0Ue5m4T|DSPxc5K`^eqxTJQ6-MTl<#u0WBBB!Kc&x05C&~Ys?c2ab- z!l6w`B1#}rq-FQff>?1Jv9_`ia(XCpkr`#iw5@#4Cmp{I^c!W5XWvCf5jE9&T-z!|rEODDy_|g-6LmPM zr^t=(A$ybL6O)Wj&?7Tfk>ycT2nTd)Wz%YQy~f;|eQKmu#+wgEe3C?0k^xfbpL96(2&fk7#?Bgqcxp?jTF6A0-A2{cy_D?u!}~ja ztT)c$o<;IRd5Fyo$K(D6k|^D?B%nd99=t{6ER`8*lOl&qW7}4`q?)Lc-q|>-La%Li z?ukxW#oPQIcR;*D8%i@eR*mS7v&U5eCCg6HpMHiZ%e$hI4vTc4R8jfbVWZO6HW85w z$qJ-U5WA_JH05o#a~!?>9vo`c3b3NGh(w4H4y~1t{Gi`pi(eb&#)O8{iGfHa4eM#x zfQF7oBRq>=BaVBu{YXyyQ7nDVXD^BqzlhHRERqOLd>m}C7rnih-H%g>(*3C3kL2T8 zR075fN5k7k@RnJgYo&JRe*k&AdiY4bU;gGWK@4bqKpcPF+xC}>W%&$NrhJyfb0nT8 zF+~D%sg+Z)iNmSGMV#r_oI74CHe1r~b^;aEa+-`r$Bk1Re`gsM&S?MM++8oukx#7HHb z#_!Qf(Q@! 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + Note that we do not include block overlap height and width, since they are + always 0 in the official implementation. + + For good performance, it is advisable by the official implementation to + divide the distorted image in to the same size patched as used for the + construction of multivariate Gaussian model. + + Args: + img (ndarray): Input image whose quality needs to be computed. The + image must be a gray or Y (of YCbCr) image with shape (h, w). + Range [0, 255] with float type. + mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian + model calculated on the pristine dataset. + cov_pris_param (ndarray): Covariance of a pre-defined multivariate + Gaussian model calculated on the pristine dataset. + gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the + image. + block_size_h (int): Height of the blocks in to which image is divided. + Default: 96 (the official recommended value). + block_size_w (int): Width of the blocks in to which image is divided. + Default: 96 (the official recommended value). 
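+
+    Returns:
+        float: The NIQE score of the image (lower values indicate better
+            perceptual quality).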
+    """
+    assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+    # crop image
+    h, w = img.shape
+    num_block_h = math.floor(h / block_size_h)
+    num_block_w = math.floor(w / block_size_w)
+    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+
+    distparam = []  # distparam holds the multiscale features
+    for scale in (1, 2):  # perform on two scales (1, 2)
+        mu = convolve(img, gaussian_window, mode='nearest')
+        sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
+        # normalize, as in Eq. 1 in the paper
+        img_normalized = (img - mu) / (sigma + 1)
+
+        feat = []
+        for idx_w in range(num_block_w):
+            for idx_h in range(num_block_h):
+                # process each block
+                block = img_normalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
+                                       idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
+                feat.append(compute_feature(block))
+
+        distparam.append(np.array(feat))
+
+        if scale == 1:
+            img = imresize(img / 255., scale=0.5, antialiasing=True)
+            img = img * 255.
+
+    distparam = np.concatenate(distparam, axis=1)
+
+    # fit a MVG (multivariate Gaussian) model to distorted patch features
+    mu_distparam = np.nanmean(distparam, axis=0)
+    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
+    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
+    cov_distparam = np.cov(distparam_no_nan, rowvar=False)
+
+    # compute niqe quality, Eq. 10 in the paper
+    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
+    quality = np.matmul(
+        np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
+
+    quality = np.sqrt(quality)
+    quality = float(np.squeeze(quality))
+    return quality
+
+
+@METRIC_REGISTRY.register()
+def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
+    """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+    Ref: Making a "Completely Blind" Image Quality Analyzer.
+    This implementation could produce almost the same results as the official
+    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+    > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
+    > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
+
+    We use the official params estimated from the pristine dataset.
+    We use the recommended block size (96, 96) without overlaps.
+
+    Args:
+        img (ndarray): Input image whose quality needs to be computed.
+            The input image must be in range [0, 255] with float/int type.
+            The input order of the image can be 'HW', 'HWC' or 'CHW' (BGR order).
+            If the input order is 'HWC' or 'CHW', it will be converted to a gray
+            or Y (of YCbCr) image according to the ``convert_to`` argument.
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the metric calculation.
+        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        convert_to (str): Whether to convert to 'y' (of MATLAB YCbCr) or 'gray'.
+            Default: 'y'.
+
+    Returns:
+        float: NIQE result.
+    """
+    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+    # we use the official params estimated from the pristine dataset.
+ niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz')) + mu_pris_param = niqe_pris_params['mu_pris_param'] + cov_pris_param = niqe_pris_params['cov_pris_param'] + gaussian_window = niqe_pris_params['gaussian_window'] + + img = img.astype(np.float32) + if input_order != 'HW': + img = reorder_image(img, input_order=input_order) + if convert_to == 'y': + img = to_y_channel(img) + elif convert_to == 'gray': + img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255. + img = np.squeeze(img) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border] + + # round is necessary for being consistent with MATLAB's result + img = img.round() + + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) + + return niqe_result diff --git a/ultraled/metrics/psnr_ssim.py b/ultraled/metrics/psnr_ssim.py new file mode 100644 index 0000000..439ad7d --- /dev/null +++ b/ultraled/metrics/psnr_ssim.py @@ -0,0 +1,233 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from ultraled.metrics.metric_util import reorder_image, to_y_channel +from ultraled.utils.color_util import rgb2ycbcr_pt +from ultraled.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. 
+ """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) 
+ return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. + """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) diff --git a/ultraled/metrics/test_metrics/test_psnr_ssim.py b/ultraled/metrics/test_metrics/test_psnr_ssim.py new file mode 100644 index 0000000..eab898b --- /dev/null +++ b/ultraled/metrics/test_metrics/test_psnr_ssim.py @@ -0,0 +1,52 @@ +import cv2 +import torch + +from ultraled.metrics import calculate_psnr, calculate_ssim +from ultraled.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt +from ultraled.utils import img2tensor + + +def test(img_path, img_path2, crop_border, test_y_channel=False): + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) + + # --------------------- Numpy --------------------- + psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) + ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) + print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') + + # --------------------- PyTorch (CPU) --------------------- + img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + + psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + ssim_pth = 
calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') + + # --------------------- PyTorch (GPU) --------------------- + img = img.cuda() + img2 = img2.cuda() + psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') + + psnr_pth = calculate_psnr_pt( + torch.repeat_interleave(img, 2, dim=0), + torch.repeat_interleave(img2, 2, dim=0), + crop_border=crop_border, + test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt( + torch.repeat_interleave(img, 2, dim=0), + torch.repeat_interleave(img2, 2, dim=0), + crop_border=crop_border, + test_y_channel=test_y_channel) + print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' + f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') + + +if __name__ == '__main__': + test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) + test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True) + + test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) + test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
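
The added test script only exercises the PSNR/SSIM metrics. Below is a minimal sketch of how the new NIQE metric could be checked in the same style; it assumes the NIQE code above lives at ultraled/metrics/niqe.py, and the image path mirrors the hypothetical tests/data layout used in the PSNR/SSIM test rather than a file known to exist in this repository. The value is printed for manual comparison against the MATLAB reference quoted in the calculate_niqe docstring, not asserted.

    import cv2

    # assumed module path, inferred from the __pycache__ entry for ultraled/metrics
    from ultraled.metrics.niqe import calculate_niqe


    def test_niqe(img_path, crop_border):
        # Read as BGR uint8; calculate_niqe expects range [0, 255] and, with
        # convert_to='y', converts to the Y channel of YCbCr internally.
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        niqe_val = calculate_niqe(img, crop_border=crop_border, input_order='HWC', convert_to='y')
        print(f'\tNIQE: {niqe_val:.6f}')


    if __name__ == '__main__':
        # hypothetical test image path, following the layout used above
        test_niqe('tests/data/gt/baboon.png', crop_border=0)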