diff --git a/datasets/txtfiles/HDR/train.txt b/datasets/txtfiles/HDR/train.txt new file mode 100644 index 0000000..42d0241 --- /dev/null +++ b/datasets/txtfiles/HDR/train.txt @@ -0,0 +1,226 @@ +./LDR/7U6A3450.CR2 ./HDR/train/7U6A3449_hdr +./LDR/7U6A3453.CR2 ./HDR/train/7U6A3452_hdr +./LDR/7U6A3456.CR2 ./HDR/train/7U6A3455_hdr +./LDR/7U6A3459.CR2 ./HDR/train/7U6A3458_hdr +./LDR/7U6A3462.CR2 ./HDR/train/7U6A3461_hdr +./LDR/7U6A3474.CR2 ./HDR/train/7U6A3473_hdr +./LDR/7U6A3477.CR2 ./HDR/train/7U6A3476_hdr +./LDR/7U6A3480.CR2 ./HDR/train/7U6A3479_hdr +./LDR/7U6A3495.CR2 ./HDR/train/7U6A3494_hdr +./LDR/7U6A3513.CR2 ./HDR/train/7U6A3512_hdr +./LDR/7U6A8060.CR2 ./HDR/train/7U6A8059_hdr +./LDR/7U6A8069.CR2 ./HDR/train/7U6A8068_hdr +./LDR/7U6A8075.CR2 ./HDR/train/7U6A8074_hdr +./LDR/7U6A8081.CR2 ./HDR/train/7U6A8080_hdr +./LDR/7U6A8084.CR2 ./HDR/train/7U6A8083_hdr +./LDR/7U6A8087.CR2 ./HDR/train/7U6A8086_hdr +./LDR/7U6A8096.CR2 ./HDR/train/7U6A8095_hdr +./LDR/7U6A8108.CR2 ./HDR/train/7U6A8107_hdr +./LDR/7U6A8111.CR2 ./HDR/train/7U6A8110_hdr +./LDR/7U6A8126.CR2 ./HDR/train/7U6A8125_hdr +./LDR/7U6A8129.CR2 ./HDR/train/7U6A8128_hdr +./LDR/7U6A8135.CR2 ./HDR/train/7U6A8134_hdr +./LDR/7U6A8141.CR2 ./HDR/train/7U6A8140_hdr +./LDR/7U6A8144.CR2 ./HDR/train/7U6A8143_hdr +./LDR/7U6A8159.CR2 ./HDR/train/7U6A8158_hdr +./LDR/7U6A8168.CR2 ./HDR/train/7U6A8167_hdr +./LDR/7U6A8174.CR2 ./HDR/train/7U6A8173_hdr +./LDR/7U6A8177.CR2 ./HDR/train/7U6A8176_hdr +./LDR/7U6A8180.CR2 ./HDR/train/7U6A8179_hdr +./LDR/7U6A8189.CR2 ./HDR/train/7U6A8188_hdr +./LDR/7U6A8192.CR2 ./HDR/train/7U6A8191_hdr +./LDR/7U6A8195.CR2 ./HDR/train/7U6A8194_hdr +./LDR/7U6A8201.CR2 ./HDR/train/7U6A8200_hdr +./LDR/7U6A8207.CR2 ./HDR/train/7U6A8206_hdr +./LDR/7U6A8219.CR2 ./HDR/train/7U6A8218_hdr +./LDR/7U6A8222.CR2 ./HDR/train/7U6A8221_hdr +./LDR/7U6A8225.CR2 ./HDR/train/7U6A8224_hdr +./LDR/7U6A8231.CR2 ./HDR/train/7U6A8230_hdr +./LDR/7U6A8234.CR2 ./HDR/train/7U6A8233_hdr +./LDR/7U6A8240.CR2 ./HDR/train/7U6A8239_hdr +./LDR/7U6A8249.CR2 ./HDR/train/7U6A8248_hdr +./LDR/7U6A8252.CR2 ./HDR/train/7U6A8251_hdr +./LDR/7U6A8258.CR2 ./HDR/train/7U6A8257_hdr +./LDR/7U6A8264.CR2 ./HDR/train/7U6A8263_hdr +./LDR/7U6A8270.CR2 ./HDR/train/7U6A8269_hdr +./LDR/7U6A8273.CR2 ./HDR/train/7U6A8272_hdr +./LDR/7U6A8276.CR2 ./HDR/train/7U6A8275_hdr +./LDR/7U6A8279.CR2 ./HDR/train/7U6A8278_hdr +./LDR/7U6A8285.CR2 ./HDR/train/7U6A8284_hdr +./LDR/7U6A8294.CR2 ./HDR/train/7U6A8293_hdr +./LDR/7U6A8300.CR2 ./HDR/train/7U6A8299_hdr +./LDR/7U6A8303.CR2 ./HDR/train/7U6A8302_hdr +./LDR/7U6A8315.CR2 ./HDR/train/7U6A8314_hdr +./LDR/7U6A8318.CR2 ./HDR/train/7U6A8317_hdr +./LDR/7U6A8321.CR2 ./HDR/train/7U6A8320_hdr +./LDR/7U6A8330.CR2 ./HDR/train/7U6A8329_hdr +./LDR/7U6A8336.CR2 ./HDR/train/7U6A8335_hdr +./LDR/7U6A8339.CR2 ./HDR/train/7U6A8338_hdr +./LDR/7U6A8348.CR2 ./HDR/train/7U6A8347_hdr +./LDR/7U6A8351.CR2 ./HDR/train/7U6A8350_hdr +./LDR/7U6A8354.CR2 ./HDR/train/7U6A8353_hdr +./LDR/7U6A8357.CR2 ./HDR/train/7U6A8356_hdr +./LDR/7U6A8360.CR2 ./HDR/train/7U6A8359_hdr +./LDR/7U6A8366.CR2 ./HDR/train/7U6A8365_hdr +./LDR/7U6A8372.CR2 ./HDR/train/7U6A8371_hdr +./LDR/7U6A8381.CR2 ./HDR/train/7U6A8380_hdr +./LDR/7U6A8384.CR2 ./HDR/train/7U6A8383_hdr +./LDR/7U6A8395.CR2 ./HDR/train/7U6A8394_hdr +./LDR/7U6A8398.CR2 ./HDR/train/7U6A8397_hdr +./LDR/7U6A8404.CR2 ./HDR/train/7U6A8403_hdr +./LDR/7U6A8407.CR2 ./HDR/train/7U6A8406_hdr +./LDR/7U6A8413.CR2 ./HDR/train/7U6A8412_hdr +./LDR/7U6A8416.CR2 ./HDR/train/7U6A8415_hdr +./LDR/7U6A8419.CR2 ./HDR/train/7U6A8418_hdr +./LDR/7U6A8422.CR2 ./HDR/train/7U6A8421_hdr +./LDR/7U6A8428.CR2 ./HDR/train/7U6A8427_hdr +./LDR/7U6A8434.CR2 ./HDR/train/7U6A8433_hdr +./LDR/7U6A8440.CR2 ./HDR/train/7U6A8439_hdr +./LDR/7U6A8443.CR2 ./HDR/train/7U6A8442_hdr +./LDR/7U6A8446.CR2 ./HDR/train/7U6A8445_hdr +./LDR/7U6A8452.CR2 ./HDR/train/7U6A8451_hdr +./LDR/7U6A8455.CR2 ./HDR/train/7U6A8454_hdr +./LDR/7U6A8458.CR2 ./HDR/train/7U6A8457_hdr +./LDR/7U6A8461.CR2 ./HDR/train/7U6A8460_hdr +./LDR/7U6A8473.CR2 ./HDR/train/7U6A8472_hdr +./LDR/7U6A8479.CR2 ./HDR/train/7U6A8478_hdr +./LDR/7U6A8482.CR2 ./HDR/train/7U6A8481_hdr +./LDR/7U6A8491.CR2 ./HDR/train/7U6A8490_hdr +./LDR/7U6A8494.CR2 ./HDR/train/7U6A8493_hdr +./LDR/7U6A8497.CR2 ./HDR/train/7U6A8496_hdr +./LDR/7U6A8500.CR2 ./HDR/train/7U6A8499_hdr +./LDR/7U6A8503.CR2 ./HDR/train/7U6A8502_hdr +./LDR/7U6A8506.CR2 ./HDR/train/7U6A8505_hdr +./LDR/7U6A8509.CR2 ./HDR/train/7U6A8508_hdr +./LDR/7U6A8512.CR2 ./HDR/train/7U6A8511_hdr +./LDR/7U6A8518.CR2 ./HDR/train/7U6A8517_hdr +./LDR/7U6A8521.CR2 ./HDR/train/7U6A8520_hdr +./LDR/7U6A8530.CR2 ./HDR/train/7U6A8529_hdr +./LDR/7U6A8539.CR2 ./HDR/train/7U6A8538_hdr +./LDR/7U6A8542.CR2 ./HDR/train/7U6A8541_hdr +./LDR/7U6A8545.CR2 ./HDR/train/7U6A8544_hdr +./LDR/7U6A8548.CR2 ./HDR/train/7U6A8547_hdr +./LDR/7U6A8551.CR2 ./HDR/train/7U6A8550_hdr +./LDR/7U6A8554.CR2 ./HDR/train/7U6A8553_hdr +./LDR/7U6A8560.CR2 ./HDR/train/7U6A8559_hdr +./LDR/7U6A8575.CR2 ./HDR/train/7U6A8574_hdr +./LDR/7U6A8578.CR2 ./HDR/train/7U6A8577_hdr +./LDR/7U6A8596.CR2 ./HDR/train/7U6A8595_hdr +./LDR/7U6A8599.CR2 ./HDR/train/7U6A8598_hdr +./LDR/7U6A8602.CR2 ./HDR/train/7U6A8601_hdr +./LDR/7U6A8605.CR2 ./HDR/train/7U6A8604_hdr +./LDR/7U6A8609.CR2 ./HDR/train/7U6A8608_hdr +./LDR/7U6A8633.CR2 ./HDR/train/7U6A8632_hdr +./LDR/7U6A8636.CR2 ./HDR/train/7U6A8635_hdr +./LDR/7U6A8642.CR2 ./HDR/train/7U6A8641_hdr +./LDR/7U6A8645.CR2 ./HDR/train/7U6A8644_hdr +./LDR/7U6A8648.CR2 ./HDR/train/7U6A8647_hdr +./LDR/7U6A8654.CR2 ./HDR/train/7U6A8653_hdr +./LDR/7U6A8660.CR2 ./HDR/train/7U6A8659_hdr +./LDR/7U6A8663.CR2 ./HDR/train/7U6A8662_hdr +./LDR/7U6A8678.CR2 ./HDR/train/7U6A8677_hdr +./LDR/7U6A8684.CR2 ./HDR/train/7U6A8683_hdr +./LDR/7U6A8690.CR2 ./HDR/train/7U6A8689_hdr +./LDR/7U6A8693.CR2 ./HDR/train/7U6A8692_hdr +./LDR/7U6A8699.CR2 ./HDR/train/7U6A8698_hdr +./LDR/7U6A8702.CR2 ./HDR/train/7U6A8701_hdr +./LDR/7U6A8705.CR2 ./HDR/train/7U6A8704_hdr +./LDR/7U6A8711.CR2 ./HDR/train/7U6A8710_hdr +./LDR/7U6A8714.CR2 ./HDR/train/7U6A8713_hdr +./LDR/7U6A8720.CR2 ./HDR/train/7U6A8719_hdr +./LDR/7U6A8723.CR2 ./HDR/train/7U6A8722_hdr +./LDR/7U6A8729.CR2 ./HDR/train/7U6A8728_hdr +./LDR/7U6A8735.CR2 ./HDR/train/7U6A8734_hdr +./LDR/7U6A8738.CR2 ./HDR/train/7U6A8737_hdr +./LDR/7U6A8744.CR2 ./HDR/train/7U6A8743_hdr +./LDR/7U6A8747.CR2 ./HDR/train/7U6A8746_hdr +./LDR/7U6A8750.CR2 ./HDR/train/7U6A8749_hdr +./LDR/7U6A8753.CR2 ./HDR/train/7U6A8752_hdr +./LDR/7U6A8762.CR2 ./HDR/train/7U6A8761_hdr +./LDR/7U6A8774.CR2 ./HDR/train/7U6A8773_hdr +./LDR/7U6A8786.CR2 ./HDR/train/7U6A8785_hdr +./LDR/7U6A8789.CR2 ./HDR/train/7U6A8788_hdr +./LDR/7U6A8884.CR2 ./HDR/train/7U6A8883_hdr +./LDR/7U6A8887.CR2 ./HDR/train/7U6A8886_hdr +./LDR/7U6A8896.CR2 ./HDR/train/7U6A8895_hdr +./LDR/7U6A8899.CR2 ./HDR/train/7U6A8898_hdr +./LDR/7U6A8905.CR2 ./HDR/train/7U6A8904_hdr +./LDR/7U6A8908.CR2 ./HDR/train/7U6A8907_hdr +./LDR/7U6A8917.CR2 ./HDR/train/7U6A8916_hdr +./LDR/7U6A8920.CR2 ./HDR/train/7U6A8919_hdr +./LDR/7U6A8926.CR2 ./HDR/train/7U6A8925_hdr +./LDR/7U6A8929.CR2 ./HDR/train/7U6A8928_hdr +./LDR/7U6A8935.CR2 ./HDR/train/7U6A8934_hdr +./LDR/7U6A8938.CR2 ./HDR/train/7U6A8937_hdr +./LDR/7U6A8944.CR2 ./HDR/train/7U6A8943_hdr +./LDR/7U6A8947.CR2 ./HDR/train/7U6A8946_hdr +./LDR/7U6A8953.CR2 ./HDR/train/7U6A8952_hdr +./LDR/7U6A8959.CR2 ./HDR/train/7U6A8958_hdr +./LDR/7U6A8965.CR2 ./HDR/train/7U6A8964_hdr +./LDR/7U6A8968.CR2 ./HDR/train/7U6A8967_hdr +./LDR/7U6A8976.CR2 ./HDR/train/7U6A8975_hdr +./LDR/7U6A8985.CR2 ./HDR/train/7U6A8984_hdr +./LDR/7U6A8991.CR2 ./HDR/train/7U6A8990_hdr +./LDR/7U6A8994.CR2 ./HDR/train/7U6A8993_hdr +./LDR/7U6A8997.CR2 ./HDR/train/7U6A8996_hdr +./LDR/7U6A9006.CR2 ./HDR/train/7U6A9005_hdr +./LDR/7U6A9009.CR2 ./HDR/train/7U6A9008_hdr +./LDR/7U6A9016.CR2 ./HDR/train/7U6A9015_hdr +./LDR/7U6A9022.CR2 ./HDR/train/7U6A9021_hdr +./LDR/7U6A9028.CR2 ./HDR/train/7U6A9027_hdr +./LDR/7U6A9031.CR2 ./HDR/train/7U6A9030_hdr +./LDR/7U6A9061.CR2 ./HDR/train/7U6A9060_hdr +./LDR/7U6A9064.CR2 ./HDR/train/7U6A9063_hdr +./LDR/7U6A9067.CR2 ./HDR/train/7U6A9066_hdr +./LDR/7U6A9079.CR2 ./HDR/train/7U6A9078_hdr +./LDR/7U6A9082.CR2 ./HDR/train/7U6A9081_hdr +./LDR/7U6A9088.CR2 ./HDR/train/7U6A9087_hdr +./LDR/7U6A9091.CR2 ./HDR/train/7U6A9090_hdr +./LDR/7U6A9094.CR2 ./HDR/train/7U6A9093_hdr +./LDR/7U6A9115.CR2 ./HDR/train/7U6A9114_hdr +./LDR/7U6A9121.CR2 ./HDR/train/7U6A9120_hdr +./LDR/7U6A9136.CR2 ./HDR/train/7U6A9135_hdr +./LDR/7U6A9139.CR2 ./HDR/train/7U6A9138_hdr +./LDR/7U6A9142.CR2 ./HDR/train/7U6A9141_hdr +./LDR/7U6A9145.CR2 ./HDR/train/7U6A9144_hdr +./LDR/7U6A9151.CR2 ./HDR/train/7U6A9150_hdr +./LDR/7U6A9160.CR2 ./HDR/train/7U6A9159_hdr +./LDR/7U6A9163.CR2 ./HDR/train/7U6A9162_hdr +./LDR/7U6A9172.CR2 ./HDR/train/7U6A9171_hdr +./LDR/7U6A9181.CR2 ./HDR/train/7U6A9180_hdr +./LDR/7U6A9184.CR2 ./HDR/train/7U6A9183_hdr +./LDR/7U6A9187.CR2 ./HDR/train/7U6A9186_hdr +./LDR/7U6A9190.CR2 ./HDR/train/7U6A9189_hdr +./LDR/7U6A9196.CR2 ./HDR/train/7U6A9195_hdr +./LDR/7U6A9208.CR2 ./HDR/train/7U6A9207_hdr +./LDR/7U6A9214.CR2 ./HDR/train/7U6A9213_hdr +./LDR/7U6A9220.CR2 ./HDR/train/7U6A9219_hdr +./LDR/7U6A9229.CR2 ./HDR/train/7U6A9228_hdr +./LDR/7U6A9238.CR2 ./HDR/train/7U6A9237_hdr +./LDR/7U6A9244.CR2 ./HDR/train/7U6A9243_hdr +./LDR/7U6A9250.CR2 ./HDR/train/7U6A9249_hdr +./LDR/7U6A9253.CR2 ./HDR/train/7U6A9252_hdr +./LDR/7U6A9259.CR2 ./HDR/train/7U6A9258_hdr +./LDR/7U6A9262.CR2 ./HDR/train/7U6A9261_hdr +./LDR/7U6A9265.CR2 ./HDR/train/7U6A9264_hdr +./LDR/7U6A9268.CR2 ./HDR/train/7U6A9267_hdr +./LDR/7U6A9274.CR2 ./HDR/train/7U6A9273_hdr +./LDR/7U6A9277.CR2 ./HDR/train/7U6A9276_hdr +./LDR/7U6A9280.CR2 ./HDR/train/7U6A9279_hdr +./LDR/7U6A9283.CR2 ./HDR/train/7U6A9282_hdr +./LDR/7U6A9304.CR2 ./HDR/train/7U6A9303_hdr +./LDR/7U6A9307.CR2 ./HDR/train/7U6A9306_hdr +./LDR/7U6A9310.CR2 ./HDR/train/7U6A9309_hdr +./LDR/7U6A9313.CR2 ./HDR/train/7U6A9312_hdr +./LDR/7U6A9319.CR2 ./HDR/train/7U6A9318_hdr +./LDR/7U6A9325.CR2 ./HDR/train/7U6A9324_hdr +./LDR/7U6A9331.CR2 ./HDR/train/7U6A9330_hdr +./LDR/7U6A9340.CR2 ./HDR/train/7U6A9339_hdr +./LDR/7U6A9352.CR2 ./HDR/train/7U6A9351_hdr +./LDR/7U6A9370.CR2 ./HDR/train/7U6A9369_hdr +./LDR/7U6A9376.CR2 ./HDR/train/7U6A9375_hdr +./LDR/7U6A9379.CR2 ./HDR/train/7U6A9378_hdr +./LDR/7U6A9382.CR2 ./HDR/train/7U6A9381_hdr +./LDR/7U6A9385.CR2 ./HDR/train/7U6A9384_hdr +./LDR/7U6A9400.CR2 ./HDR/train/7U6A9399_hdr +./LDR/7U6A9403.CR2 ./HDR/train/7U6A9402_hdr diff --git a/datasets/txtfiles/HDR/train1.txt b/datasets/txtfiles/HDR/train1.txt new file mode 100644 index 0000000..6e58fd4 --- /dev/null +++ b/datasets/txtfiles/HDR/train1.txt @@ -0,0 +1,226 @@ +./LDR/7U6A3449.CR2 ./HDR/train/7U6A3449_hdr +./LDR/7U6A3452.CR2 ./HDR/train/7U6A3452_hdr +./LDR/7U6A3455.CR2 ./HDR/train/7U6A3455_hdr +./LDR/7U6A3458.CR2 ./HDR/train/7U6A3458_hdr +./LDR/7U6A3461.CR2 ./HDR/train/7U6A3461_hdr +./LDR/7U6A3473.CR2 ./HDR/train/7U6A3473_hdr +./LDR/7U6A3476.CR2 ./HDR/train/7U6A3476_hdr +./LDR/7U6A3479.CR2 ./HDR/train/7U6A3479_hdr +./LDR/7U6A3494.CR2 ./HDR/train/7U6A3494_hdr +./LDR/7U6A3512.CR2 ./HDR/train/7U6A3512_hdr +./LDR/7U6A8059.CR2 ./HDR/train/7U6A8059_hdr +./LDR/7U6A8068.CR2 ./HDR/train/7U6A8068_hdr +./LDR/7U6A8074.CR2 ./HDR/train/7U6A8074_hdr +./LDR/7U6A8080.CR2 ./HDR/train/7U6A8080_hdr +./LDR/7U6A8083.CR2 ./HDR/train/7U6A8083_hdr +./LDR/7U6A8086.CR2 ./HDR/train/7U6A8086_hdr +./LDR/7U6A8095.CR2 ./HDR/train/7U6A8095_hdr +./LDR/7U6A8107.CR2 ./HDR/train/7U6A8107_hdr +./LDR/7U6A8110.CR2 ./HDR/train/7U6A8110_hdr +./LDR/7U6A8125.CR2 ./HDR/train/7U6A8125_hdr +./LDR/7U6A8128.CR2 ./HDR/train/7U6A8128_hdr +./LDR/7U6A8134.CR2 ./HDR/train/7U6A8134_hdr +./LDR/7U6A8140.CR2 ./HDR/train/7U6A8140_hdr +./LDR/7U6A8143.CR2 ./HDR/train/7U6A8143_hdr +./LDR/7U6A8158.CR2 ./HDR/train/7U6A8158_hdr +./LDR/7U6A8167.CR2 ./HDR/train/7U6A8167_hdr +./LDR/7U6A8173.CR2 ./HDR/train/7U6A8173_hdr +./LDR/7U6A8176.CR2 ./HDR/train/7U6A8176_hdr +./LDR/7U6A8179.CR2 ./HDR/train/7U6A8179_hdr +./LDR/7U6A8188.CR2 ./HDR/train/7U6A8188_hdr +./LDR/7U6A8191.CR2 ./HDR/train/7U6A8191_hdr +./LDR/7U6A8194.CR2 ./HDR/train/7U6A8194_hdr +./LDR/7U6A8200.CR2 ./HDR/train/7U6A8200_hdr +./LDR/7U6A8206.CR2 ./HDR/train/7U6A8206_hdr +./LDR/7U6A8218.CR2 ./HDR/train/7U6A8218_hdr +./LDR/7U6A8221.CR2 ./HDR/train/7U6A8221_hdr +./LDR/7U6A8224.CR2 ./HDR/train/7U6A8224_hdr +./LDR/7U6A8230.CR2 ./HDR/train/7U6A8230_hdr +./LDR/7U6A8233.CR2 ./HDR/train/7U6A8233_hdr +./LDR/7U6A8239.CR2 ./HDR/train/7U6A8239_hdr +./LDR/7U6A8248.CR2 ./HDR/train/7U6A8248_hdr +./LDR/7U6A8251.CR2 ./HDR/train/7U6A8251_hdr +./LDR/7U6A8257.CR2 ./HDR/train/7U6A8257_hdr +./LDR/7U6A8263.CR2 ./HDR/train/7U6A8263_hdr +./LDR/7U6A8269.CR2 ./HDR/train/7U6A8269_hdr +./LDR/7U6A8272.CR2 ./HDR/train/7U6A8272_hdr +./LDR/7U6A8275.CR2 ./HDR/train/7U6A8275_hdr +./LDR/7U6A8278.CR2 ./HDR/train/7U6A8278_hdr +./LDR/7U6A8284.CR2 ./HDR/train/7U6A8284_hdr +./LDR/7U6A8293.CR2 ./HDR/train/7U6A8293_hdr +./LDR/7U6A8299.CR2 ./HDR/train/7U6A8299_hdr +./LDR/7U6A8302.CR2 ./HDR/train/7U6A8302_hdr +./LDR/7U6A8314.CR2 ./HDR/train/7U6A8314_hdr +./LDR/7U6A8317.CR2 ./HDR/train/7U6A8317_hdr +./LDR/7U6A8320.CR2 ./HDR/train/7U6A8320_hdr +./LDR/7U6A8329.CR2 ./HDR/train/7U6A8329_hdr +./LDR/7U6A8335.CR2 ./HDR/train/7U6A8335_hdr +./LDR/7U6A8338.CR2 ./HDR/train/7U6A8338_hdr +./LDR/7U6A8347.CR2 ./HDR/train/7U6A8347_hdr +./LDR/7U6A8350.CR2 ./HDR/train/7U6A8350_hdr +./LDR/7U6A8353.CR2 ./HDR/train/7U6A8353_hdr +./LDR/7U6A8356.CR2 ./HDR/train/7U6A8356_hdr +./LDR/7U6A8359.CR2 ./HDR/train/7U6A8359_hdr +./LDR/7U6A8365.CR2 ./HDR/train/7U6A8365_hdr +./LDR/7U6A8371.CR2 ./HDR/train/7U6A8371_hdr +./LDR/7U6A8380.CR2 ./HDR/train/7U6A8380_hdr +./LDR/7U6A8383.CR2 ./HDR/train/7U6A8383_hdr +./LDR/7U6A8394.CR2 ./HDR/train/7U6A8394_hdr +./LDR/7U6A8397.CR2 ./HDR/train/7U6A8397_hdr +./LDR/7U6A8403.CR2 ./HDR/train/7U6A8403_hdr +./LDR/7U6A8406.CR2 ./HDR/train/7U6A8406_hdr +./LDR/7U6A8412.CR2 ./HDR/train/7U6A8412_hdr +./LDR/7U6A8415.CR2 ./HDR/train/7U6A8415_hdr +./LDR/7U6A8418.CR2 ./HDR/train/7U6A8418_hdr +./LDR/7U6A8421.CR2 ./HDR/train/7U6A8421_hdr +./LDR/7U6A8427.CR2 ./HDR/train/7U6A8427_hdr +./LDR/7U6A8433.CR2 ./HDR/train/7U6A8433_hdr +./LDR/7U6A8439.CR2 ./HDR/train/7U6A8439_hdr +./LDR/7U6A8442.CR2 ./HDR/train/7U6A8442_hdr +./LDR/7U6A8445.CR2 ./HDR/train/7U6A8445_hdr +./LDR/7U6A8451.CR2 ./HDR/train/7U6A8451_hdr +./LDR/7U6A8454.CR2 ./HDR/train/7U6A8454_hdr +./LDR/7U6A8457.CR2 ./HDR/train/7U6A8457_hdr +./LDR/7U6A8460.CR2 ./HDR/train/7U6A8460_hdr +./LDR/7U6A8472.CR2 ./HDR/train/7U6A8472_hdr +./LDR/7U6A8478.CR2 ./HDR/train/7U6A8478_hdr +./LDR/7U6A8481.CR2 ./HDR/train/7U6A8481_hdr +./LDR/7U6A8490.CR2 ./HDR/train/7U6A8490_hdr +./LDR/7U6A8493.CR2 ./HDR/train/7U6A8493_hdr +./LDR/7U6A8496.CR2 ./HDR/train/7U6A8496_hdr +./LDR/7U6A8499.CR2 ./HDR/train/7U6A8499_hdr +./LDR/7U6A8502.CR2 ./HDR/train/7U6A8502_hdr +./LDR/7U6A8505.CR2 ./HDR/train/7U6A8505_hdr +./LDR/7U6A8508.CR2 ./HDR/train/7U6A8508_hdr +./LDR/7U6A8511.CR2 ./HDR/train/7U6A8511_hdr +./LDR/7U6A8517.CR2 ./HDR/train/7U6A8517_hdr +./LDR/7U6A8520.CR2 ./HDR/train/7U6A8520_hdr +./LDR/7U6A8529.CR2 ./HDR/train/7U6A8529_hdr +./LDR/7U6A8538.CR2 ./HDR/train/7U6A8538_hdr +./LDR/7U6A8541.CR2 ./HDR/train/7U6A8541_hdr +./LDR/7U6A8544.CR2 ./HDR/train/7U6A8544_hdr +./LDR/7U6A8547.CR2 ./HDR/train/7U6A8547_hdr +./LDR/7U6A8550.CR2 ./HDR/train/7U6A8550_hdr +./LDR/7U6A8553.CR2 ./HDR/train/7U6A8553_hdr +./LDR/7U6A8559.CR2 ./HDR/train/7U6A8559_hdr +./LDR/7U6A8574.CR2 ./HDR/train/7U6A8574_hdr +./LDR/7U6A8577.CR2 ./HDR/train/7U6A8577_hdr +./LDR/7U6A8595.CR2 ./HDR/train/7U6A8595_hdr +./LDR/7U6A8598.CR2 ./HDR/train/7U6A8598_hdr +./LDR/7U6A8601.CR2 ./HDR/train/7U6A8601_hdr +./LDR/7U6A8604.CR2 ./HDR/train/7U6A8604_hdr +./LDR/7U6A8608.CR2 ./HDR/train/7U6A8608_hdr +./LDR/7U6A8632.CR2 ./HDR/train/7U6A8632_hdr +./LDR/7U6A8635.CR2 ./HDR/train/7U6A8635_hdr +./LDR/7U6A8641.CR2 ./HDR/train/7U6A8641_hdr +./LDR/7U6A8644.CR2 ./HDR/train/7U6A8644_hdr +./LDR/7U6A8647.CR2 ./HDR/train/7U6A8647_hdr +./LDR/7U6A8653.CR2 ./HDR/train/7U6A8653_hdr +./LDR/7U6A8659.CR2 ./HDR/train/7U6A8659_hdr +./LDR/7U6A8662.CR2 ./HDR/train/7U6A8662_hdr +./LDR/7U6A8677.CR2 ./HDR/train/7U6A8677_hdr +./LDR/7U6A8683.CR2 ./HDR/train/7U6A8683_hdr +./LDR/7U6A8689.CR2 ./HDR/train/7U6A8689_hdr +./LDR/7U6A8692.CR2 ./HDR/train/7U6A8692_hdr +./LDR/7U6A8698.CR2 ./HDR/train/7U6A8698_hdr +./LDR/7U6A8701.CR2 ./HDR/train/7U6A8701_hdr +./LDR/7U6A8704.CR2 ./HDR/train/7U6A8704_hdr +./LDR/7U6A8710.CR2 ./HDR/train/7U6A8710_hdr +./LDR/7U6A8713.CR2 ./HDR/train/7U6A8713_hdr +./LDR/7U6A8719.CR2 ./HDR/train/7U6A8719_hdr +./LDR/7U6A8722.CR2 ./HDR/train/7U6A8722_hdr +./LDR/7U6A8728.CR2 ./HDR/train/7U6A8728_hdr +./LDR/7U6A8734.CR2 ./HDR/train/7U6A8734_hdr +./LDR/7U6A8737.CR2 ./HDR/train/7U6A8737_hdr +./LDR/7U6A8743.CR2 ./HDR/train/7U6A8743_hdr +./LDR/7U6A8746.CR2 ./HDR/train/7U6A8746_hdr +./LDR/7U6A8749.CR2 ./HDR/train/7U6A8749_hdr +./LDR/7U6A8752.CR2 ./HDR/train/7U6A8752_hdr +./LDR/7U6A8761.CR2 ./HDR/train/7U6A8761_hdr +./LDR/7U6A8773.CR2 ./HDR/train/7U6A8773_hdr +./LDR/7U6A8785.CR2 ./HDR/train/7U6A8785_hdr +./LDR/7U6A8788.CR2 ./HDR/train/7U6A8788_hdr +./LDR/7U6A8883.CR2 ./HDR/train/7U6A8883_hdr +./LDR/7U6A8886.CR2 ./HDR/train/7U6A8886_hdr +./LDR/7U6A8895.CR2 ./HDR/train/7U6A8895_hdr +./LDR/7U6A8898.CR2 ./HDR/train/7U6A8898_hdr +./LDR/7U6A8904.CR2 ./HDR/train/7U6A8904_hdr +./LDR/7U6A8907.CR2 ./HDR/train/7U6A8907_hdr +./LDR/7U6A8916.CR2 ./HDR/train/7U6A8916_hdr +./LDR/7U6A8919.CR2 ./HDR/train/7U6A8919_hdr +./LDR/7U6A8925.CR2 ./HDR/train/7U6A8925_hdr +./LDR/7U6A8928.CR2 ./HDR/train/7U6A8928_hdr +./LDR/7U6A8934.CR2 ./HDR/train/7U6A8934_hdr +./LDR/7U6A8937.CR2 ./HDR/train/7U6A8937_hdr +./LDR/7U6A8943.CR2 ./HDR/train/7U6A8943_hdr +./LDR/7U6A8946.CR2 ./HDR/train/7U6A8946_hdr +./LDR/7U6A8952.CR2 ./HDR/train/7U6A8952_hdr +./LDR/7U6A8958.CR2 ./HDR/train/7U6A8958_hdr +./LDR/7U6A8964.CR2 ./HDR/train/7U6A8964_hdr +./LDR/7U6A8967.CR2 ./HDR/train/7U6A8967_hdr +./LDR/7U6A8975.CR2 ./HDR/train/7U6A8975_hdr +./LDR/7U6A8984.CR2 ./HDR/train/7U6A8984_hdr +./LDR/7U6A8990.CR2 ./HDR/train/7U6A8990_hdr +./LDR/7U6A8993.CR2 ./HDR/train/7U6A8993_hdr +./LDR/7U6A8996.CR2 ./HDR/train/7U6A8996_hdr +./LDR/7U6A9005.CR2 ./HDR/train/7U6A9005_hdr +./LDR/7U6A9008.CR2 ./HDR/train/7U6A9008_hdr +./LDR/7U6A9015.CR2 ./HDR/train/7U6A9015_hdr +./LDR/7U6A9021.CR2 ./HDR/train/7U6A9021_hdr +./LDR/7U6A9027.CR2 ./HDR/train/7U6A9027_hdr +./LDR/7U6A9030.CR2 ./HDR/train/7U6A9030_hdr +./LDR/7U6A9060.CR2 ./HDR/train/7U6A9060_hdr +./LDR/7U6A9063.CR2 ./HDR/train/7U6A9063_hdr +./LDR/7U6A9066.CR2 ./HDR/train/7U6A9066_hdr +./LDR/7U6A9078.CR2 ./HDR/train/7U6A9078_hdr +./LDR/7U6A9081.CR2 ./HDR/train/7U6A9081_hdr +./LDR/7U6A9087.CR2 ./HDR/train/7U6A9087_hdr +./LDR/7U6A9090.CR2 ./HDR/train/7U6A9090_hdr +./LDR/7U6A9093.CR2 ./HDR/train/7U6A9093_hdr +./LDR/7U6A9114.CR2 ./HDR/train/7U6A9114_hdr +./LDR/7U6A9120.CR2 ./HDR/train/7U6A9120_hdr +./LDR/7U6A9135.CR2 ./HDR/train/7U6A9135_hdr +./LDR/7U6A9138.CR2 ./HDR/train/7U6A9138_hdr +./LDR/7U6A9141.CR2 ./HDR/train/7U6A9141_hdr +./LDR/7U6A9144.CR2 ./HDR/train/7U6A9144_hdr +./LDR/7U6A9150.CR2 ./HDR/train/7U6A9150_hdr +./LDR/7U6A9159.CR2 ./HDR/train/7U6A9159_hdr +./LDR/7U6A9162.CR2 ./HDR/train/7U6A9162_hdr +./LDR/7U6A9171.CR2 ./HDR/train/7U6A9171_hdr +./LDR/7U6A9180.CR2 ./HDR/train/7U6A9180_hdr +./LDR/7U6A9183.CR2 ./HDR/train/7U6A9183_hdr +./LDR/7U6A9186.CR2 ./HDR/train/7U6A9186_hdr +./LDR/7U6A9189.CR2 ./HDR/train/7U6A9189_hdr +./LDR/7U6A9195.CR2 ./HDR/train/7U6A9195_hdr +./LDR/7U6A9207.CR2 ./HDR/train/7U6A9207_hdr +./LDR/7U6A9213.CR2 ./HDR/train/7U6A9213_hdr +./LDR/7U6A9219.CR2 ./HDR/train/7U6A9219_hdr +./LDR/7U6A9228.CR2 ./HDR/train/7U6A9228_hdr +./LDR/7U6A9237.CR2 ./HDR/train/7U6A9237_hdr +./LDR/7U6A9243.CR2 ./HDR/train/7U6A9243_hdr +./LDR/7U6A9249.CR2 ./HDR/train/7U6A9249_hdr +./LDR/7U6A9252.CR2 ./HDR/train/7U6A9252_hdr +./LDR/7U6A9258.CR2 ./HDR/train/7U6A9258_hdr +./LDR/7U6A9261.CR2 ./HDR/train/7U6A9261_hdr +./LDR/7U6A9264.CR2 ./HDR/train/7U6A9264_hdr +./LDR/7U6A9267.CR2 ./HDR/train/7U6A9267_hdr +./LDR/7U6A9273.CR2 ./HDR/train/7U6A9273_hdr +./LDR/7U6A9276.CR2 ./HDR/train/7U6A9276_hdr +./LDR/7U6A9279.CR2 ./HDR/train/7U6A9279_hdr +./LDR/7U6A9282.CR2 ./HDR/train/7U6A9282_hdr +./LDR/7U6A9303.CR2 ./HDR/train/7U6A9303_hdr +./LDR/7U6A9306.CR2 ./HDR/train/7U6A9306_hdr +./LDR/7U6A9309.CR2 ./HDR/train/7U6A9309_hdr +./LDR/7U6A9312.CR2 ./HDR/train/7U6A9312_hdr +./LDR/7U6A9318.CR2 ./HDR/train/7U6A9318_hdr +./LDR/7U6A9324.CR2 ./HDR/train/7U6A9324_hdr +./LDR/7U6A9330.CR2 ./HDR/train/7U6A9330_hdr +./LDR/7U6A9339.CR2 ./HDR/train/7U6A9339_hdr +./LDR/7U6A9351.CR2 ./HDR/train/7U6A9351_hdr +./LDR/7U6A9369.CR2 ./HDR/train/7U6A9369_hdr +./LDR/7U6A9375.CR2 ./HDR/train/7U6A9375_hdr +./LDR/7U6A9378.CR2 ./HDR/train/7U6A9378_hdr +./LDR/7U6A9381.CR2 ./HDR/train/7U6A9381_hdr +./LDR/7U6A9384.CR2 ./HDR/train/7U6A9384_hdr +./LDR/7U6A9399.CR2 ./HDR/train/7U6A9399_hdr +./LDR/7U6A9402.CR2 ./HDR/train/7U6A9402_hdr diff --git a/options/UltraLED/train_step1.yaml b/options/UltraLED/train_step1.yaml new file mode 100644 index 0000000..46350c2 --- /dev/null +++ b/options/UltraLED/train_step1.yaml @@ -0,0 +1,129 @@ +# general settings +name: Train_step1 +model_type: RatioMapEstimatorModel +scale: 1 +num_gpu: 1 +manual_seed: 2022 +metric_in_srgb: false +CRF_path: datasets/EMoR + +# TODO: update the path to the dataroot +# dataset and data loader settings +datasets: + train: + name: EstimatorTrain + type: EstimatorHDRRAWDataset + dataroot: path/to/RAWHDRnpz + postfix: npz + which_meta: gt + data_pair_list: datasets/txtfiles/HDR/train.txt + zero_clip: false + crop_size: ~ + load_in_mem: true + + ratio_range: [1, 31] + noise_type: ptrqc + + + use_patches: false + + use_hflip: true + use_rot: true + crop_size: 512 + load_in_mem: false + + # data loadecr + num_worker_per_gpu: 8 + batch_size_per_gpu: 1 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + +noise_g: + noise_type: ptrqc + camera_params: + SonyA7M4: + Kmin: 0.5176822379876143 + Kmax: 37.48342111254088 + Row: + slope: 0.8498868627654588 + bias: -2.508114778989769 + sigma: 0.2539644785570933 + Gaussian: + slope: 0.9469078551868472 + bias: 0.16045036103138344 + sigma: 0.20521559771812753 + TurkeyLambda: + slope: 0.8741245752046553 + bias: -0.7834371464250983 + sigma: 0.24359869793646144 + lambda: [-0.09686645120382309, -0.20266661047935486, -0.21817433834075928, -0.2772795557975769, -0.26944655179977417, -0.271323561668396, -0.2704177498817444, -0.2713994085788727, -0.27039051055908203, -0.2710208296775818, -0.2720312476158142, -0.2751300632953644, -0.2763080298900604, -0.27588126063346863, -0.2749999165534973, -0.27548256516456604, -0.27511167526245117, -0.276935875415802, -0.28096914291381836, -0.2802513539791107, -0.28123438358306885, -0.28208523988723755, -0.28266826272010803, -0.2647216022014618, -0.26392486691474915, -0.26441627740859985, -0.2643108367919922, -0.2627350389957428, -0.26077425479888916, -0.2635541260242462] + ColorBias: [[0.23521856874227523, 0.22085916072130204, 0.2863488158583641, 0.22595973789691925], [0.3256998673081398, 0.3030224359035492, 0.34010037809610366, 0.3117868521809578], [0.2417914468050003, 0.23460582494735718, 0.2810402739048004, 0.2346685606241226], [0.2893064874410629, 0.2924955692887306, 0.3213977232575417, 0.32337920874357223], [0.25456790179014205, 0.2652137302979827, 0.304450766146183, 0.29432806588709354], [0.1782449632883072, 0.30211541056632996, 0.2497684806585312, 0.24955467879772186], [0.43470048904418945, 0.4313555061817169, 0.34231114387512207, 0.35701221227645874], [0.34076905250549316, -0.01008229423314333, 0.42228585481643677, 0.144450843334198], [0.09368899464607239, 0.10415662825107574, 0.3285943269729614, 0.05827445909380913], [0.13945339620113373, 0.23084817826747894, 0.13442949950695038, 0.28657567501068115], [0.252593609985197, 0.2493764951825142, 0.612515584230423, 0.6146858245134353], [0.16977471113204956, -0.08632060885429382, 0.5320069193840027, 0.3316039741039276], [0.3789997398853302, 0.4710197448730469, 0.5400363206863403, 0.7923370003700256], [0.3004434108734131, 0.2512214779853821, 0.7026330232620239, 0.65498948097229], [0.3602710623666644, 0.3627533086016774, 0.8907627087831497, 1.0329724252223969], [-0.18567052483558655, 0.4224573075771332, 0.8947332501411438, 1.0053366422653198], [0.6430071445275098, 0.3353357759770006, 2.0499183797836302, 1.9346534669399262], [0.960235595703125, -0.4397820830345154, 1.5321524143218994, 1.1829473972320557]] + engine: torch + + +# network structures +network_g: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 + +network_d: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 + +# path to the pretrained models +network_d_path: ~ + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +train: + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.999] + + scheduler: + type: HandieLR + milestones: [20000, 24000] + lrs: [9.0e-5, 8.0e-5] + + total_iter: 25000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: RAWL1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 9999999999999 + save_img: false + calculate_metric_in_batch: true + illumination_correct: true + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 2 + test_y_channel: false + ssim: # metric name, can be arbitrary + type: calculate_ssim + crop_border: 2 + test_y_channel: false + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1288 + use_tb_logger: true \ No newline at end of file diff --git a/options/UltraLED/train_step2.yaml b/options/UltraLED/train_step2.yaml new file mode 100644 index 0000000..f1cc1d1 --- /dev/null +++ b/options/UltraLED/train_step2.yaml @@ -0,0 +1,131 @@ +# general settings +name: Train_step2 +model_type: RAWDenoiserModel +scale: 1 +num_gpu: 1 +manual_seed: 2022 +metric_in_srgb: false +CRF_path: datasets/EMoR + +# TODO: update the path to the dataroot +# dataset and data loader settings +datasets: + train: + name: DenoiserTrain + type: DenoiserHDRRAWDataset + dataroot: path/to/RAWHDRnpz + postfix: npz + which_meta: gt + data_pair_list: datasets/txtfiles/HDR/train.txt + zero_clip: false + crop_size: ~ + load_in_mem: true + + ratio_range: [1, 31] + noise_type: ptrqc + + + use_patches: false + + use_hflip: true + use_rot: true + crop_size: 512 + load_in_mem: false + + # data loadecr + num_worker_per_gpu: 10 + batch_size_per_gpu: 8 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + +noise_g: + noise_type: ptrqc + camera_params: + SonyA7M4: + Kmin: 0.5176822379876143 + Kmax: 37.48342111254088 + Row: + slope: 0.8498868627654588 + bias: -2.508114778989769 + sigma: 0.2539644785570933 + Gaussian: + slope: 0.9469078551868472 + bias: 0.16045036103138344 + sigma: 0.20521559771812753 + TurkeyLambda: + slope: 0.8741245752046553 + bias: -0.7834371464250983 + sigma: 0.24359869793646144 + lambda: [-0.09686645120382309, -0.20266661047935486, -0.21817433834075928, -0.2772795557975769, -0.26944655179977417, -0.271323561668396, -0.2704177498817444, -0.2713994085788727, -0.27039051055908203, -0.2710208296775818, -0.2720312476158142, -0.2751300632953644, -0.2763080298900604, -0.27588126063346863, -0.2749999165534973, -0.27548256516456604, -0.27511167526245117, -0.276935875415802, -0.28096914291381836, -0.2802513539791107, -0.28123438358306885, -0.28208523988723755, -0.28266826272010803, -0.2647216022014618, -0.26392486691474915, -0.26441627740859985, -0.2643108367919922, -0.2627350389957428, -0.26077425479888916, -0.2635541260242462] + ColorBias: [[0.23521856874227523, 0.22085916072130204, 0.2863488158583641, 0.22595973789691925], [0.3256998673081398, 0.3030224359035492, 0.34010037809610366, 0.3117868521809578], [0.2417914468050003, 0.23460582494735718, 0.2810402739048004, 0.2346685606241226], [0.2893064874410629, 0.2924955692887306, 0.3213977232575417, 0.32337920874357223], [0.25456790179014205, 0.2652137302979827, 0.304450766146183, 0.29432806588709354], [0.1782449632883072, 0.30211541056632996, 0.2497684806585312, 0.24955467879772186], [0.43470048904418945, 0.4313555061817169, 0.34231114387512207, 0.35701221227645874], [0.34076905250549316, -0.01008229423314333, 0.42228585481643677, 0.144450843334198], [0.09368899464607239, 0.10415662825107574, 0.3285943269729614, 0.05827445909380913], [0.13945339620113373, 0.23084817826747894, 0.13442949950695038, 0.28657567501068115], [0.252593609985197, 0.2493764951825142, 0.612515584230423, 0.6146858245134353], [0.16977471113204956, -0.08632060885429382, 0.5320069193840027, 0.3316039741039276], [0.3789997398853302, 0.4710197448730469, 0.5400363206863403, 0.7923370003700256], [0.3004434108734131, 0.2512214779853821, 0.7026330232620239, 0.65498948097229], [0.3602710623666644, 0.3627533086016774, 0.8907627087831497, 1.0329724252223969], [-0.18567052483558655, 0.4224573075771332, 0.8947332501411438, 1.0053366422653198], [0.6430071445275098, 0.3353357759770006, 2.0499183797836302, 1.9346534669399262], [0.960235595703125, -0.4397820830345154, 1.5321524143218994, 1.1829473972320557]] + engine: torch + +# network structures +network_g: + type: CUNetArch + inchannels: 4 + outchannels: 4 + channels: 32 + +network_d: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 + +# TODO: update the path to the pretrained ratio map estimator models +# path to the pretrained ratio map estimator models +network_d_path: path/to/the_pretrained_ratio_map_estimator_model + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.999] + + scheduler: + type: CosineAnnealingRestartLR + eta_min: !!float 1e-5 + periods: [96600, 193200, 289800] + restart_weights: [1, 0.5, 0.25] + + total_iter: 270480 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# validation settings +val: + val_freq: !!float 9999999999999 + save_img: false + calculate_metric_in_batch: true + illumination_correct: true + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 2 + test_y_channel: false + ssim: # metric name, can be arbitrary + type: calculate_ssim + crop_border: 2 + test_y_channel: false + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1288 + use_tb_logger: true diff --git a/options/base/network_g/cunet.yaml b/options/base/network_g/cunet.yaml new file mode 100644 index 0000000..d41faf5 --- /dev/null +++ b/options/base/network_g/cunet.yaml @@ -0,0 +1,5 @@ +network_g: + type: CUNetArch + inchannels: 4 + outchannels: 4 + channels: 32 \ No newline at end of file diff --git a/options/base/network_g/unet41.yaml b/options/base/network_g/unet41.yaml new file mode 100644 index 0000000..5d42d6c --- /dev/null +++ b/options/base/network_g/unet41.yaml @@ -0,0 +1,5 @@ +network_g: + type: UNetArch + inchannels: 4 + outchannels: 1 + channels: 32 \ No newline at end of file diff --git a/scripts/image_process_ultraled.py b/scripts/image_process_ultraled.py new file mode 100644 index 0000000..f31c133 --- /dev/null +++ b/scripts/image_process_ultraled.py @@ -0,0 +1,137 @@ +import argparse +import glob +import os +import time +from copy import deepcopy +import math + +import cv2 +import numpy as np +import rawpy +import torch +import torch.nn.functional as F +from tqdm import tqdm +import sys + +from ultraled.archs import build_network +from ultraled.utils.options import yaml_load +from ultraled.data.raw_utils import * + + +def load_network(net, load_path, strict = True, param_key = 'params'): + """Load network weights from checkpoint.""" + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + print('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + + print(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + + net.load_state_dict(load_net, strict=strict) + + +def get_available_device(): + """Get available computing device.""" + if torch.cuda.is_available(): + return torch.device('cuda') + if torch.backends.mps.is_available(): + return torch.device('mps') + return torch.device('cpu') + + + +def setup_network(network_options, pretrained_path): + """Setup and load network.""" + print('Building network...') + print(network_options) + network = build_network(yaml_load(network_options)['network_g']) + + print('Loading checkpoint...') + load_network(network, pretrained_path) + + device = get_available_device() + return network.to(device) + + +@torch.no_grad() +def image_process(): + """Main image processing pipeline.""" + parser = argparse.ArgumentParser(description='Image processing pipeline') + parser.add_argument('-p', '--pretrained_network', type=str, required=True, + help='the pretrained ratio map estimator network path.') + parser.add_argument('-pd', '--pretrained_denosing_network', type=str, required=True, + help='the pretrained network path for denoising.') + parser.add_argument('--data_path', type=str, required=True, + help='the folder containing raw images to be processed.') + parser.add_argument('--save_path', type=str, default='inference/image_process', + help='output folder for processed images.') + parser.add_argument('-opt', '--network_options', + default='options/base/network_g/cunet.yaml', + help='ratio map estimator network architecture options.') + parser.add_argument('-optd', '--denoising_network_options', + default='options/base/network_g/41unet.yaml', + help='denoising network architecture options.') + parser.add_argument('--ratio', '--dgain', type=float, default=1.0, + help='maximum exposure gain ratio.') + parser.add_argument('--target_exposure', type=float, + help='target exposure (overrides ratio).') + parser.add_argument('--bps', '--output_bps', type=int, default=8, + help='output bit depth.') + + args = parser.parse_args() + + device = get_available_device() + + network_g = setup_network(args.network_options, args.pretrained_network) + network_gd = setup_network(args.denoising_network_options, args.pretrained_denosing_network) + + raw_paths = sorted(glob.glob(f'{args.data_path}/*')) + ratio = args.ratio + os.makedirs(args.save_path, exist_ok=True) + + for raw_path in tqdm(raw_paths, desc="Processing images"): + start_time = time.time() + + if args.target_exposure is not None: + iso, exp_time = metainfo(raw_path) + ratio = args.target_exposure / (iso * exp_time) + + raw, raw_pattern, im, bl, wl = read_img(raw_path) + im0 = (im - bl) / (wl - bl) + im_normalized = im0 * ratio + + im_normalized = im_normalized.to(device) + result = network_g(im_normalized) + result = filter_bilateral(result, 15, torch.tensor(15.0).cuda(), torch.tensor(1.0).cuda()) + result = result.cpu() + + ratiomap_output = result + realmap = torch.tensor(ratio / ratiomap_output).to(device) + result = im0 * ratio / ratiomap_output + + result = result.clip(0.0, 1.0).to(device) + + ratio1 = realmap.mean().item() + result1 = network_gd(result, realmap, if_train=False) + result1 = result1.cpu().clip(0, 1) + + rgb = postprocess(raw, raw_pattern, result1, bl, wl, args.bps) + rgb_tensor = torch.FloatTensor(rgb / 1.0) + rgb_final = rgb_tensor.cpu().numpy().astype(np.uint8) + + base_save_path = raw_path.replace(args.data_path, args.save_path) + cv2.imwrite(f'{base_save_path}.png', rgb_final) + + raw.close() + + +if __name__ == '__main__': + image_process() \ No newline at end of file diff --git a/scripts/test_metrics_ultraled.py b/scripts/test_metrics_ultraled.py new file mode 100644 index 0000000..89bd4d1 --- /dev/null +++ b/scripts/test_metrics_ultraled.py @@ -0,0 +1,129 @@ +import os +import glob +import numpy as np +import cv2 +import torch +from torch import nn +import pyiqa +from tqdm import tqdm +import re +import sys + +from ultraled.data.raw_utils import * + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +def match_image_pairs(folder_path_A, folder_path_B): + pairs = [] + + for fname in os.listdir(folder_path_A): + if not fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): + continue + + match = re.search(r'_(\d+)\.', fname) + if (match and int(match.group(1)) == 10) or (match and int(match.group(1)) == 7): + continue + + try: + prefix = fname.split('_')[0] + scene_num = int(prefix) + except: + continue + + target_fname = f"scene{scene_num}hdr.png" + target_path = os.path.join(folder_path_B, target_fname) + + if os.path.exists(target_path): + src_path = os.path.join(folder_path_A, fname) + pairs.append((src_path, target_path)) + + return pairs + +def calculate_metrics_with_correction(folder_A, folder_B): + psnr_metric = pyiqa.create_metric('psnr', device=device) + ssim_metric = pyiqa.create_metric('ssim', device=device) + lpips_metric = pyiqa.create_metric('lpips', device=device) + + image_pairs = match_image_pairs(folder_A, folder_B) + + if not image_pairs: + print("No valid image pairs found") + return + + total_psnr = 0.0 + total_ssim = 0.0 + total_lpips = 0.0 + valid_pairs = 0 + + pbar = tqdm(image_pairs, desc="Processing") + for a_path, b_path in pbar: + try: + img_a = load_image(a_path) + img_b = load_image(b_path) + + img_a = img_a / 255.0 + img_b = img_b / 255.0 + + if img_a.shape[:2] != img_b.shape[:2]: + img_b = resize_image(img_b, target_shape=img_a.shape[:2], is_mask=False) + + tensor_a = image_to_tensor(img_a).clip(0, 1).to(device) + tensor_b = image_to_tensor(img_b).clip(0, 1).to(device) + + corrected_a = illuminance_correct(tensor_a, tensor_b) + corrected_a = corrected_a.clamp(0, 1) + + with torch.no_grad(): + psnr_score = psnr_metric(corrected_a, tensor_b) + ssim_score = ssim_metric(corrected_a, tensor_b) + lpips_score = lpips_metric(corrected_a, tensor_b) + + total_psnr += psnr_score.item() + total_ssim += ssim_score.item() + total_lpips += lpips_score.item() + valid_pairs += 1 + + current_metrics = { + 'PSNR': f"{psnr_score.item():.4f}", + 'SSIM': f"{ssim_score.item():.4f}", + 'LPIPS': f"{lpips_score.item():.4f}" + } + + pbar.set_postfix(current_metrics) + + except Exception as e: + pbar.write(f"Failed {os.path.basename(a_path)}: {str(e)}") + continue + + if valid_pairs > 0: + avg_psnr = total_psnr / valid_pairs + avg_ssim = total_ssim / valid_pairs + avg_lpips = total_lpips / valid_pairs + + print(f" PSNR: {avg_psnr:.4f}", f" SSIM: {avg_ssim:.4f}", f" LPIPS: {avg_lpips:.4f}") + + return { + 'psnr': avg_psnr, + 'ssim': avg_ssim, + 'lpips': avg_lpips, + 'valid_pairs': valid_pairs + } + else: + print("No valid image pairs for evaluation") + return None + +if __name__ == "__main__": + # Place test data in folder A and ground truth in folder B. When evaluating metrics, do not alter the naming format of source files. + results = calculate_metrics_with_correction( + folder_A="../../../defaultShare/archive/mengyuang/NeurIPS/final_test/ratio50_final_test", + folder_B="../../../defaultShare/archive/mengyuang/SonyA7M4data_latest/groundtruth" + ) + results = calculate_metrics_with_correction( + folder_A="../../../defaultShare/archive/mengyuang/NeurIPS/final_test/ratio100_final_test", + folder_B="../../../defaultShare/archive/mengyuang/SonyA7M4data_latest/groundtruth" + ) + results = calculate_metrics_with_correction( + folder_A="../../../defaultShare/archive/mengyuang/NeurIPS/final_test/ratio200_final_test", + folder_B="../../../defaultShare/archive/mengyuang/SonyA7M4data_latest/groundtruth" + ) \ No newline at end of file diff --git a/ultraled/__init__.py b/ultraled/__init__.py new file mode 100644 index 0000000..2843754 --- /dev/null +++ b/ultraled/__init__.py @@ -0,0 +1,12 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .test import * +from .train import * +from .utils import * +# from .version import __gitsha__, __version__ diff --git a/ultraled/__pycache__/__init__.cpython-38.pyc b/ultraled/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..4bf5c66 Binary files /dev/null and b/ultraled/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/__pycache__/test.cpython-38.pyc b/ultraled/__pycache__/test.cpython-38.pyc new file mode 100644 index 0000000..f1c9744 Binary files /dev/null and b/ultraled/__pycache__/test.cpython-38.pyc differ diff --git a/ultraled/__pycache__/train.cpython-38.pyc b/ultraled/__pycache__/train.cpython-38.pyc new file mode 100644 index 0000000..61f59f3 Binary files /dev/null and b/ultraled/__pycache__/train.cpython-38.pyc differ diff --git a/ultraled/archs/__init__.py b/ultraled/archs/__init__.py new file mode 100644 index 0000000..ecff4ec --- /dev/null +++ b/ultraled/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'ultraled.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/ultraled/archs/__pycache__/__init__.cpython-38.pyc b/ultraled/archs/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..2c111cc Binary files /dev/null and b/ultraled/archs/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/archs/__pycache__/cunet_arch.cpython-38.pyc b/ultraled/archs/__pycache__/cunet_arch.cpython-38.pyc new file mode 100644 index 0000000..df3d63d Binary files /dev/null and b/ultraled/archs/__pycache__/cunet_arch.cpython-38.pyc differ diff --git a/ultraled/archs/__pycache__/norm_util.cpython-38.pyc b/ultraled/archs/__pycache__/norm_util.cpython-38.pyc new file mode 100644 index 0000000..a06fb8a Binary files /dev/null and b/ultraled/archs/__pycache__/norm_util.cpython-38.pyc differ diff --git a/ultraled/archs/__pycache__/unet_arch.cpython-38.pyc b/ultraled/archs/__pycache__/unet_arch.cpython-38.pyc new file mode 100644 index 0000000..943b4a3 Binary files /dev/null and b/ultraled/archs/__pycache__/unet_arch.cpython-38.pyc differ diff --git a/ultraled/archs/arch_util.py b/ultraled/archs/arch_util.py new file mode 100644 index 0000000..418983a --- /dev/null +++ b/ultraled/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from ultraled.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from ultraled.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. + """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/ultraled/archs/cunet_arch.py b/ultraled/archs/cunet_arch.py new file mode 100644 index 0000000..9f9169a --- /dev/null +++ b/ultraled/archs/cunet_arch.py @@ -0,0 +1,212 @@ +from ultraled.archs.norm_util import MultipleScaleNorm2d, ScaleNorm2d +from ultraled.utils.registry import ARCH_REGISTRY + +import torch +from torch import nn + +### Normalization +from torch.nn import Identity +from torch.nn import BatchNorm2d, InstanceNorm2d +from ultraled.archs.norm_util import LayerNorm2d + +import torch +from torch import nn +from torch.nn import functional as F +import math + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.LeakyReLU(negative_slope=0.2, inplace=True) + ) + + +class ResidualBlock(nn.Module): + def __init__(self, channels) -> None: + super().__init__() + + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x, y): + out1 = self.conv1(x) + out11 = self.lrelu(out1) + out2 = self.conv2(out11) + out21 = out2 + y + + return out21 + + def lrelu(self, x): + outt = torch.max(0.2 * x, x) + return outt + + +@ARCH_REGISTRY.register() +class CUNetArch(nn.Module): + def __init__(self, inchannels=3, outchannels=3, channels=32) -> None: + super().__init__() + + self.conv0_1 = nn.Linear(in_features=300, out_features=256) + self.conv0_2 = nn.Linear(in_features=256, out_features=128) + + self.conv1_1 = nn.Conv2d(inchannels, channels, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.pool1 = nn.MaxPool2d(kernel_size=2) + + self.conv2_1 = nn.Conv2d(channels, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + self.pool2 = nn.MaxPool2d(kernel_size=2) + + + self.conv3_1 = nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv3_2 = ResidualBlock(channels * 4) + self.pool3 = nn.MaxPool2d(kernel_size=2) + + self.conv4_1 = nn.Conv2d(channels * 4, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv4_2 = ResidualBlock(channels * 8) + self.pool4 = nn.MaxPool2d(kernel_size=2) + + self.conv5_1 = nn.Conv2d(channels * 8, channels * 16, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(channels * 16, channels * 16, kernel_size=3, stride=1, padding=1) + + + self.upv6 = nn.ConvTranspose2d(channels * 16, channels * 8, 2, stride=2) + self.conv6_1 = nn.Conv2d(channels * 16, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv6_2 = ResidualBlock(channels * 8) + + + self.upv7 = nn.ConvTranspose2d(channels * 8, channels * 4, 2, stride=2) + self.conv7_1 = nn.Conv2d(channels * 8, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv7_2 = ResidualBlock(channels * 4) + + self.upv8 = nn.ConvTranspose2d(channels * 4, channels * 2, 2, stride=2) + self.conv8_1 = nn.Conv2d(channels * 4, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv8_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + + self.upv9 = nn.ConvTranspose2d(channels * 2, channels, 2, stride=2) + self.conv9_1 = nn.Conv2d(channels * 2, channels, kernel_size=3, stride=1, padding=1) + self.conv9_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + self.conv10_1 = nn.Conv2d(channels, outchannels, kernel_size=1, stride=1) + + def _check_and_padding(self, x): + + _, _, h, w = x.size() + stride = (2 ** (5 - 1)) + + dh = -h % stride + dw = -w % stride + + top_pad = dh // 2 + bottom_pad = dh - top_pad + left_pad = dw // 2 + right_pad = dw - left_pad + self.crop_indices = (left_pad, w+left_pad, top_pad, h+top_pad) + + padded_tensor = F.pad( + x, (left_pad, right_pad, top_pad, bottom_pad), mode="reflect" + ) + + return padded_tensor + + def _check_and_crop(self, x): + left, right, top, bottom = self.crop_indices + x = x[:, :, top:bottom, left:right] + return x + + def ratio_map_encoding(self, y): + sigma = 30 + r = torch.arange(0, 300).cuda() + + + Hr, Wr = 128, 128 + r = r.view(1, 300, 1, 1).expand(-1, -1, Hr, Wr) + y = F.interpolate(y, size=(Hr, Wr), mode='bilinear', align_corners=False) + r = torch.exp(-((r - y) ** 2) / (2 * sigma * sigma)) / (math.sqrt(2 * math.pi) * sigma) + r = torch.mul(r, 1 / y) + + return r + + + def forward(self, x, y, if_train=True): + r = self.ratio_map_encoding(y) + + if if_train: + Hc4, Wc4 = int(x.size(2) / 4), int(x.size(3) / 4) + else: + Hc4, Wc4 = int(x.size(2) / 4) + 2, int(x.size(3) / 4) + r = F.interpolate(r, size=(Hc4, Wc4), mode='bilinear', align_corners=False) + batch_size, _, H, W = r.shape + r = r.view(batch_size, -1, H * W).permute(0, 2, 1) + + + # MLP + control = self.conv0_1(r) + control = self.conv0_2(control) + control = control.permute(0, 2, 1).view(batch_size, 128, Hc4, Wc4) + + x = self._check_and_padding(x) + + + conv1 = self.lrelu(self.conv1_1(x)) + conv1 = self.lrelu(self.conv1_2(conv1)) + pool1 = self.pool1(conv1) + + conv2 = self.lrelu(self.conv2_1(pool1)) + conv2 = self.lrelu(self.conv2_2(conv2)) + pool2 = self.pool1(conv2) + + conv3 = self.lrelu(self.conv3_1(pool2)) + conv3_1 = conv3 + control + conv3 = self.lrelu(self.conv3_2(conv3, conv3_1)) + pool3 = self.pool1(conv3) + + conv4 = self.lrelu(self.conv4_1(pool3)) + conv4 = self.lrelu(self.conv4_2(conv4, conv4)) + pool4 = self.pool1(conv4) + + conv5 = self.lrelu(self.conv5_1(pool4)) + conv5 = self.lrelu(self.conv5_2(conv5)) + + up6 = self.upv6(conv5) + up6 = torch.cat([up6, conv4], 1) + conv6 = self.lrelu(self.conv6_1(up6)) + conv6 = self.lrelu(self.conv6_2(conv6, conv6)) + + up7 = self.upv7(conv6) + up7 = torch.cat([up7, conv3_1], 1) + conv7 = self.lrelu(self.conv7_1(up7)) + conv7 = self.lrelu(self.conv7_2(conv7, conv7)) + + up8 = self.upv8(conv7) + up8 = torch.cat([up8, conv2], 1) + conv8 = self.lrelu(self.conv8_1(up8)) + conv8 = self.lrelu(self.conv8_2(conv8)) + + up9 = self.upv9(conv8) + up9 = torch.cat([up9, conv1], 1) + conv9 = self.lrelu(self.conv9_1(up9)) + conv9 = self.lrelu(self.conv9_2(conv9)) + + conv10 = self.conv10_1(conv9) + out = self._check_and_crop(conv10) + return out + + + + def lrelu(self, x): + outt = torch.max(0.2 * x, x) + return outt \ No newline at end of file diff --git a/ultraled/archs/norm_util.py b/ultraled/archs/norm_util.py new file mode 100644 index 0000000..ff5b655 --- /dev/null +++ b/ultraled/archs/norm_util.py @@ -0,0 +1,59 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class LayerNorm2d(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, num_features, eps=1e-6, affine=True, data_format="channels_first", track_running_stats=False): + super().__init__() + self.weight = nn.Parameter(torch.ones(num_features)) if affine else None + self.bias = nn.Parameter(torch.zeros(num_features)) if affine else None + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.num_features = (num_features, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.num_features, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + if self.weight is not None: + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class ScaleNorm2d(nn.Module): + def __init__(self, num_features, bias=True, *, init_weight=1.0, init_bias=0.0) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1) * init_weight, requires_grad=True) + self.bias = nn.Parameter(torch.ones(1, num_features, 1, 1) * init_bias, requires_grad=bias) + + def forward(self, x): + return self.weight * x + self.bias + +class MultipleScaleNorm2d(nn.Module): + def __init__(self, num_features, bias=True, numbers=1, *, init_weight=1.0, init_bias=0.0) -> None: + super().__init__() + self.norms = nn.ModuleList() + for _ in range(numbers): + self.norms.append(ScaleNorm2d( + num_features, + bias=bias, + init_weight=init_weight, + init_bias=init_bias + )) + # self.deploy_norm = ScaleNorm2d(num_features, False) + # self.deploy_norm.weight.requires_grad_(False) + self.numbers = numbers + + def forward(self, x, idx=0): + assert idx < self.numbers + return self.norms[idx](x) diff --git a/ultraled/archs/unet_arch.py b/ultraled/archs/unet_arch.py new file mode 100644 index 0000000..584443a --- /dev/null +++ b/ultraled/archs/unet_arch.py @@ -0,0 +1,126 @@ +from ultraled.utils.registry import ARCH_REGISTRY + +import torch +from torch import nn +from torch.nn import functional as F + +# The same arch as "Learning to see in the dark" (CVPR 2018) +@ARCH_REGISTRY.register() +class UNetArch(nn.Module): + def __init__(self, inchannels=3, outchannels=3, channels=32) -> None: + super().__init__() + + self.conv1_1 = nn.Conv2d(inchannels, channels, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + self.pool1 = nn.MaxPool2d(kernel_size=2) + + self.conv2_1 = nn.Conv2d(channels, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + self.pool2 = nn.MaxPool2d(kernel_size=2) + + self.conv3_1 = nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(channels * 4, channels * 4, kernel_size=3, stride=1, padding=1) + self.pool3 = nn.MaxPool2d(kernel_size=2) + + self.conv4_1 = nn.Conv2d(channels * 4, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(channels * 8, channels * 8, kernel_size=3, stride=1, padding=1) + self.pool4 = nn.MaxPool2d(kernel_size=2) + + self.conv5_1 = nn.Conv2d(channels * 8, channels * 16, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(channels * 16, channels * 16, kernel_size=3, stride=1, padding=1) + + self.upv6 = nn.ConvTranspose2d(channels * 16, channels * 8, 2, stride=2) + self.conv6_1 = nn.Conv2d(channels * 16, channels * 8, kernel_size=3, stride=1, padding=1) + self.conv6_2 = nn.Conv2d(channels * 8, channels * 8, kernel_size=3, stride=1, padding=1) + + self.upv7 = nn.ConvTranspose2d(channels * 8, channels * 4, 2, stride=2) + self.conv7_1 = nn.Conv2d(channels * 8, channels * 4, kernel_size=3, stride=1, padding=1) + self.conv7_2 = nn.Conv2d(channels * 4, channels * 4, kernel_size=3, stride=1, padding=1) + + self.upv8 = nn.ConvTranspose2d(channels * 4, channels * 2, 2, stride=2) + self.conv8_1 = nn.Conv2d(channels * 4, channels * 2, kernel_size=3, stride=1, padding=1) + self.conv8_2 = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, stride=1, padding=1) + + self.upv9 = nn.ConvTranspose2d(channels * 2, channels, 2, stride=2) + self.conv9_1 = nn.Conv2d(channels * 2, channels, kernel_size=3, stride=1, padding=1) + self.conv9_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + self.conv10_1 = nn.Conv2d(channels, outchannels, kernel_size=1, stride=1) + + def _check_and_padding(self, x): + ### This function is totally writen by ChatGPT + # Calculate the required size based on the input size and required factor + _, _, h, w = x.size() + stride = (2 ** (5 - 1)) + + # Calculate the number of pixels needed to reach the required size + dh = -h % stride + dw = -w % stride + + # Calculate the amount of padding needed for each side + top_pad = dh // 2 + bottom_pad = dh - top_pad + left_pad = dw // 2 + right_pad = dw - left_pad + self.crop_indices = (left_pad, w+left_pad, top_pad, h+top_pad) + + # Pad the tensor with reflect mode + padded_tensor = F.pad( + x, (left_pad, right_pad, top_pad, bottom_pad), mode="reflect" + ) + + return padded_tensor + + def _check_and_crop(self, x): + left, right, top, bottom = self.crop_indices + x = x[:, :, top:bottom, left:right] + return x + + def forward(self, x): + x = self._check_and_padding(x) + conv1 = self.lrelu(self.conv1_1(x)) + conv1 = self.lrelu(self.conv1_2(conv1)) + pool1 = self.pool1(conv1) + + conv2 = self.lrelu(self.conv2_1(pool1)) + conv2 = self.lrelu(self.conv2_2(conv2)) + pool2 = self.pool1(conv2) + + conv3 = self.lrelu(self.conv3_1(pool2)) + conv3 = self.lrelu(self.conv3_2(conv3)) + pool3 = self.pool1(conv3) + + conv4 = self.lrelu(self.conv4_1(pool3)) + conv4 = self.lrelu(self.conv4_2(conv4)) + pool4 = self.pool1(conv4) + + conv5 = self.lrelu(self.conv5_1(pool4)) + conv5 = self.lrelu(self.conv5_2(conv5)) + + up6 = self.upv6(conv5) + up6 = torch.cat([up6, conv4], 1) + conv6 = self.lrelu(self.conv6_1(up6)) + conv6 = self.lrelu(self.conv6_2(conv6)) + + up7 = self.upv7(conv6) + up7 = torch.cat([up7, conv3], 1) + conv7 = self.lrelu(self.conv7_1(up7)) + conv7 = self.lrelu(self.conv7_2(conv7)) + + up8 = self.upv8(conv7) + up8 = torch.cat([up8, conv2], 1) + conv8 = self.lrelu(self.conv8_1(up8)) + conv8 = self.lrelu(self.conv8_2(conv8)) + + up9 = self.upv9(conv8) + up9 = torch.cat([up9, conv1], 1) + conv9 = self.lrelu(self.conv9_1(up9)) + conv9 = self.lrelu(self.conv9_2(conv9)) + + conv10 = self.conv10_1(conv9) + out = self._check_and_crop(conv10) + return out + + def lrelu(self, x): + outt = torch.max(0.2 * x, x) + return outt \ No newline at end of file diff --git a/ultraled/batch_test.py b/ultraled/batch_test.py new file mode 100644 index 0000000..88dc135 --- /dev/null +++ b/ultraled/batch_test.py @@ -0,0 +1,67 @@ +import logging +import torch +from os import path as osp +from glob import glob + +from ultraled.data import build_dataloader, build_dataset +from ultraled.models import build_model +from ultraled.train import init_tb_loggers +from ultraled.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs +from ultraled.utils.options import dict2str, parse_options + + +def test_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(root_path, is_train=False) + opt['root_path'] = root_path + experiments_root = osp.join(root_path, 'results', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + paths = glob(f"{opt['path']['pretrain_network_g_dir']}/*.pth") + if f"{opt['path']['pretrain_network_g_dir']}/net_g_latest.pth" in paths: + paths.pop(paths.index(f"{opt['path']['pretrain_network_g_dir']}/net_g_latest.pth")) + paths = list(sorted(paths, key=lambda x: int(x[:-4].split('_')[-1]))) + opt['path']['pretrain_network_g_dir'] = None + # create model + model = build_model(opt) + + # initialize wandb and tb loggers + tb_logger = init_tb_loggers(opt) + for load_path in paths: + if load_path.endswith('net_g_latest.pth'): + continue + param_key = opt['path'].get('param_key_g', 'params') + model.load_network(model.net_g, load_path, opt['path'].get('strict_load_g', True), param_key) + current_iter = int(load_path[:-4].split('_')[-1]) + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + model.validation(test_loader, current_iter=current_iter, tb_logger=tb_logger, save_img=opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + test_pipeline(root_path) diff --git a/ultraled/data/__init__.py b/ultraled/data/__init__.py new file mode 100644 index 0000000..9ae7c46 --- /dev/null +++ b/ultraled/data/__init__.py @@ -0,0 +1,108 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from ultraled.data.prefetch_dataloader import PrefetchDataLoader +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.dist_util import get_dist_info +from ultraled.utils.registry import DATASET_REGISTRY +import ultraled.data.collate_fn as basicsr_collate + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'ultraled.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must contain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + batch_size = dataset_opt.get('batch_size_per_gpu', 1) + num_workers = dataset_opt.get('num_worker_per_gpu', 0) + dataloader_args = dict(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + else: + raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) + collate_fn_class = dataset_opt.get('collate_fn', None) + if collate_fn_class is not None: + assert hasattr(basicsr_collate, collate_fn_class) + dataloader_args['collate_fn'] = getattr(basicsr_collate, collate_fn_class)() + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/ultraled/data/__pycache__/__init__.cpython-38.pyc b/ultraled/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..fc94bee Binary files /dev/null and b/ultraled/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/collate_fn.cpython-38.pyc b/ultraled/data/__pycache__/collate_fn.cpython-38.pyc new file mode 100644 index 0000000..b0c3e30 Binary files /dev/null and b/ultraled/data/__pycache__/collate_fn.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/commom_noise_util.cpython-38.pyc b/ultraled/data/__pycache__/commom_noise_util.cpython-38.pyc new file mode 100644 index 0000000..8f5d4d3 Binary files /dev/null and b/ultraled/data/__pycache__/commom_noise_util.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/data_sampler.cpython-38.pyc b/ultraled/data/__pycache__/data_sampler.cpython-38.pyc new file mode 100644 index 0000000..4b82a24 Binary files /dev/null and b/ultraled/data/__pycache__/data_sampler.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/hdr_paired_noiseraw_dataset.cpython-38.pyc b/ultraled/data/__pycache__/hdr_paired_noiseraw_dataset.cpython-38.pyc new file mode 100644 index 0000000..0efc6c0 Binary files /dev/null and b/ultraled/data/__pycache__/hdr_paired_noiseraw_dataset.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/hdr_util.cpython-38.pyc b/ultraled/data/__pycache__/hdr_util.cpython-38.pyc new file mode 100644 index 0000000..9a41724 Binary files /dev/null and b/ultraled/data/__pycache__/hdr_util.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/noise_util_rawhdr.cpython-38.pyc b/ultraled/data/__pycache__/noise_util_rawhdr.cpython-38.pyc new file mode 100644 index 0000000..ff10642 Binary files /dev/null and b/ultraled/data/__pycache__/noise_util_rawhdr.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/part_enhance.cpython-38.pyc b/ultraled/data/__pycache__/part_enhance.cpython-38.pyc new file mode 100644 index 0000000..4f1db72 Binary files /dev/null and b/ultraled/data/__pycache__/part_enhance.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/prefetch_dataloader.cpython-38.pyc b/ultraled/data/__pycache__/prefetch_dataloader.cpython-38.pyc new file mode 100644 index 0000000..ca3a084 Binary files /dev/null and b/ultraled/data/__pycache__/prefetch_dataloader.cpython-38.pyc differ diff --git a/ultraled/data/__pycache__/raw_utils.cpython-38.pyc b/ultraled/data/__pycache__/raw_utils.cpython-38.pyc new file mode 100644 index 0000000..ec20000 Binary files /dev/null and b/ultraled/data/__pycache__/raw_utils.cpython-38.pyc differ diff --git a/ultraled/data/collate_fn.py b/ultraled/data/collate_fn.py new file mode 100644 index 0000000..1fde0da --- /dev/null +++ b/ultraled/data/collate_fn.py @@ -0,0 +1,25 @@ +from torch.utils.data._utils.collate import collate, default_collate_fn_map +from copy import deepcopy +import torch + +DEFAUT_COLLATE_FN_MAP = default_collate_fn_map + +def starlight_collate_fn(): + collate_fn_map = deepcopy(DEFAUT_COLLATE_FN_MAP) + + def collate_tensor_fn(batch, *, collate_fn_map): + return torch.cat(batch, 0) + + def collate_list_fn(batch, *, collate_fn_map): + L = [] + for l in batch: + L.extend(l) + return L + + collate_fn_map[torch.Tensor] = collate_tensor_fn + collate_fn_map[list] = collate_list_fn + + def _collate_fn(batch): + return collate(batch, collate_fn_map=collate_fn_map) + + return _collate_fn \ No newline at end of file diff --git a/ultraled/data/commom_noise_util.py b/ultraled/data/commom_noise_util.py new file mode 100644 index 0000000..805d379 --- /dev/null +++ b/ultraled/data/commom_noise_util.py @@ -0,0 +1,316 @@ +from abc import ABC +import math +import torch +import numpy as np +from scipy import stats + +def _unpack_bayer(x): + _, h, w = x.shape + H = h*2 + W = w*2 + out = np.zeros((H, W)) + + out[0:H:2, 0:W:2] = x[0] + out[0:H:2, 1:W:2] = x[1] + out[1:H:2, 1:W:2] = x[2] + out[1:H:2, 0:W:2] = x[3] + return out + +def _pack_bayer(raw): + h, w = raw.shape + out = np.concatenate((raw[np.newaxis, 0:h:2, 0:w:2], + raw[np.newaxis, 0:h:2, 1:w:2], + raw[np.newaxis, 1:h:2, 1:w:2], + raw[np.newaxis, 1:h:2, 0:w:2]), axis=0) + return out + +def _unpack_bayer_torch(x): + _, h, w = x.shape + H = h*2 + W = w*2 + out = torch.zeros((H, W)) + + out[0:H:2, 0:W:2] = x[0] + out[0:H:2, 1:W:2] = x[1] + out[1:H:2, 1:W:2] = x[2] + out[1:H:2, 0:W:2] = x[3] + return out + +# def _pack_bayer_torch(raw): +# h, w = raw.shape +# out = torch.cat((raw[None, 0:h:2, 0:w:2], # R +# raw[None, 0:h:2, 1:w:2], # G1 +# raw[None, 1:h:2, 1:w:2], # B +# raw[None, 1:h:2, 0:w:2] # G2 +# ), dim=0) +# return out + +def _pack_bayer_torch(raw): + h, w = raw.size(-2), raw.size(-1) + out = torch.cat((raw[..., None, 0:h:2, 0:w:2], # R + raw[..., None, 0:h:2, 1:w:2], # G1 + raw[..., None, 1:h:2, 1:w:2], # B + raw[..., None, 1:h:2, 0:w:2] # G2 + ), dim=-3) + return out + +def _pack_batch_bayer_torch(raw): + _, h, w = raw.shape + out = torch.cat((raw[:, None, 0:h:2, 0:w:2], # R + raw[:, None, 0:h:2, 1:w:2], # G1 + raw[:, None, 1:h:2, 1:w:2], # B + raw[:, None, 1:h:2, 0:w:2] # G2 + ), dim=1) + return out + +def torch_shot_noise(x, k): + return torch.poisson(x / k) * k - x + +def torch_gaussian_noise(x, scale, loc=0): + return torch.randn_like(x) * scale + loc + # return torch.zeros_like(x).normal_(loc, scale) + +def torch_turkey_lambda_noise(x, scale, t_lambda=1.4): + def turkey_lambda_ppf(p, t_lambda): + # assert not torch.any(torch.tensor(t_lambda == 0.0)) + return 1 / t_lambda * (p ** t_lambda - (1 - p) ** t_lambda) + + epsilon = 1e-10 + U = torch.rand_like(x) * (1 - 2 * epsilon) + epsilon + Y = turkey_lambda_ppf(U, t_lambda + 1e-8) * scale + + return Y + +def torch_quant_noise(x, q): + return (torch.rand_like(x) - 0.5) * q + +# def torch_row_noise(x, scale, loc=0): +# _, H, W = x.shape +# noise = torch.zeros((H * 2, 1), device=x.device).normal_(loc, scale).repeat((1, W * 2)) +# return _pack_bayer_torch(noise) + +def torch_row_noise(x, scale, loc=0): + if x.dim() == 4: + B, _, H, W = x.shape + noise = (torch.randn((B, H * 2, 1), device=x.device) * scale + loc).repeat((1, 1, W * 2)) + elif x.dim() == 5: + B, T, _, H, W = x.shape + noise = (torch.randn((B, T, H * 2, 1), device=x.device) * scale + loc).repeat((1, 1, 1, W * 2)) + elif x.dim() == 3: + _, H, W = x.shape + noise = torch.zeros((H * 2, 1), device=x.device).normal_(loc, scale).repeat((1, W * 2)) + else: + raise NotImplementedError() + return _pack_bayer_torch(noise) + +def torch_batch_row_noise(x, scale, loc=0): + B, _, H, W = x.shape + noise = (torch.randn((B, H * 2, 1), device=x.device) * scale + loc).repeat((1, 1, W * 2)) + return _pack_batch_bayer_torch(noise) + +def numpy_shot_noise(x, k): + # print(x, k) + return np.random.poisson(x / k).astype(np.float32) * k - x + +def numpy_gaussian_noise(x, scale): + return stats.norm.rvs(scale=scale, size=x.shape).astype(np.float32) + +def numpy_turkey_lambda_noise(x, scale, t_lambda=1.4): + return stats.tukeylambda.rvs(t_lambda, scale=scale, size=[*x.shape]).astype(np.float32) + +def numpy_row_noise(x, scale): + _, H, W = x.shape + noise = np.random.randn(H * 2, 1).astype(np.float32) * scale + noise = np.repeat(noise, W * 2, 1) + return _pack_bayer(noise) + +def numpy_quant_noise(x, q): + return np.random.uniform(low=-0.5*q, high=0.5*q, size=x.shape) + +class Engine(ABC): + @staticmethod + def uniform(min, max, shape=None): + pass + + @staticmethod + def randint(min, max, shape=None): + pass + + @staticmethod + def randn(shape=None): + pass + + @staticmethod + def log(x): + pass + + @staticmethod + def exp(x): + pass + + @staticmethod + def to_engine_type(x): + pass + + @staticmethod + def shot_noise(x, k): + pass + + @staticmethod + def gaussian_noise(x, scale): + pass + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + pass + + @staticmethod + def row_noise(x, scale): + pass + + @staticmethod + def quant_noise(x, q): + pass + +class NumpyEngine(Engine): + @staticmethod + def uniform(min, max, shape=None): + if shape == None: + return np.random.uniform(min, max) + return np.random.uniform(min, max, size=shape) + + @staticmethod + def randint(min, max, shape=None): + if shape == None: + return np.random.randint(min, max) + return np.random.randint(min, max, size=shape) + + @staticmethod + def randn(shape=None): + if shape == None: + return np.random.randn() + return np.random.randn(shape) + + @staticmethod + def log(x): + return np.log(x) + + @staticmethod + def exp(x): + return np.exp(x) + + @staticmethod + def to_engine_type(x): + return np.array(x) + + @staticmethod + def shot_noise(x, k): + return numpy_shot_noise(x, k) + + @staticmethod + def gaussian_noise(x, scale): + return numpy_gaussian_noise(x, scale) + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + return numpy_turkey_lambda_noise(x, scale, t_lambda) + + @staticmethod + def row_noise(x, scale): + return numpy_row_noise(x, scale) + + @staticmethod + def quant_noise(x, q): + return numpy_quant_noise(x, q) + +class TorchEngine(Engine): + @staticmethod + def uniform(min, max, shape=1, device='cpu'): + if shape != 1: + return torch.rand((shape,), device=device) * (max - min) + min + return torch.rand((shape,)).item() * (max - min) + min + + @staticmethod + def randint(min, max, shape=1, device='cpu'): + if shape != 1: + return torch.randint(min, max, (shape,), device=device) + return torch.randint(min, max, (shape,)).item() + + @staticmethod + def randn(shape=1, device='cpu'): + if shape != 1: + return torch.randn((shape,), device=device) + return torch.randn((shape,)).item() + + @staticmethod + def log(x): + return math.log(x) + + @staticmethod + def exp(x): + return math.exp(x) + + @staticmethod + def to_engine_type(x): + return torch.tensor(x) + + @staticmethod + def shot_noise(x, k): + return torch_shot_noise(x, k) + + @staticmethod + def gaussian_noise(x, scale): + return torch_gaussian_noise(x, scale) + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + return torch_turkey_lambda_noise(x, scale, t_lambda) + + @staticmethod + def row_noise(x, scale): + return torch_row_noise(x, scale) + + @staticmethod + def quant_noise(x, q): + return torch_quant_noise(x, q) + + +class TorchBatchEngine(TorchEngine): + def __init__(self, use_hflip=True, use_rot=True) -> None: + super().__init__() + self.use_hflip = use_hflip + self.use_rot = use_rot + + @staticmethod + def uniform(min, max, shape=None, device='cpu'): + if shape is not None: + return torch.rand((shape,), device=device) * (max - min) + min + return torch.rand((1,)).item() * (max - min) + min + + @staticmethod + def turkey_lambda_noise(x, scale, t_lambda): + return torch_turkey_lambda_noise(x, scale, t_lambda) + + @staticmethod + def log(x): + return torch.log(x) + + @staticmethod + def exp(x): + return torch.exp(x) + + @staticmethod + def row_noise(x, scale): + return torch_batch_row_noise(x, scale) + + def augment(self, *datas): + hflip = self.use_hflip and self.randint(0, 2) == 1 + vflip = self.use_rot and self.randint(0, 2) == 1 + rot90 = self.use_rot and self.randint(0, 2) == 1 + if hflip: + datas = [torch.flip(data, (3,)) for data in datas] + if vflip: + datas = [torch.flip(data, (2,)) for data in datas] + if rot90: + datas = [torch.permute(data, (0, 1, 3, 2)) for data in datas] + return datas \ No newline at end of file diff --git a/ultraled/data/data_sampler.py b/ultraled/data/data_sampler.py new file mode 100644 index 0000000..575452d --- /dev/null +++ b/ultraled/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/ultraled/data/hdr_paired_noiseraw_dataset.py b/ultraled/data/hdr_paired_noiseraw_dataset.py new file mode 100644 index 0000000..9b58c36 --- /dev/null +++ b/ultraled/data/hdr_paired_noiseraw_dataset.py @@ -0,0 +1,520 @@ +import re +import math +from torch.utils import data as data +import numpy as np +import torch +from os import path as osp +from tqdm import tqdm +from ultraled.data.noise_util_rawhdr import NoiseGenerator + + + +from ultraled.utils.registry import DATASET_REGISTRY + + +import torch +import math + +import random +from ultraled.data.part_enhance import * + + +@DATASET_REGISTRY.register() +class EstimatorHDRRAWDataset(data.Dataset): + def __init__(self, opt) -> None: + super().__init__() + self.opt = opt + # print(self.opt) + self.engine = self.opt.get('engine', 'torch') + + self.root_folder = opt['dataroot'] + self.postfix = opt.get('postfix', None) + self.which_meta = opt.get('which_meta', 'gt') + self.zero_clip = 0 if opt.get('zero_clip', True) else None + self.ratio_range = opt.get('ratio_range', (100, 300)) + self.use_patches = opt.get('use_patches', False) + self.load_in_mem = opt.get('load_in_mem', False) + if self.use_patches: + + self.patch_id_max = opt.get('patch_id_max', 8) + self.patch_tplt = opt.get('patch_tplt', '_s{:03}') + + assert self.postfix == 'npz' + + + self.lq_paths, self.gt_paths = [], [] + with open(opt['data_pair_list'], 'r') as data_pair_list: + pairs = data_pair_list.readlines() + for pair in pairs: + lq, gt = pair.split(' ')[:2] + gt = gt.rstrip('\n') + + if not self.use_patches: + self.lq_paths.append(osp.join(self.root_folder, lq)) + self.gt_paths.append(osp.join(self.root_folder, gt)) + + else: + for i in range(1, 1 + self.patch_id_max): + self.lq_paths.append(osp.join(self.root_folder, self.insert_patch_id(lq, i, self.patch_tplt))) + self.gt_paths.append(osp.join(self.root_folder, self.insert_patch_id(gt, i, self.patch_tplt))) + + if self.load_in_mem: + self.lqs = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.lq_paths, desc='load lq metas in mem...') + } + self.gts = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.gt_paths, desc='load gt metas in mem...') + } + + @staticmethod + def insert_patch_id(path, patch_id, tplt='_s{:03}'): + exts = path.split('.') + base = exts.pop(0) + while exts[0] != 'ARW': + base += '.' + exts.pop(0) + base = base + tplt.format(patch_id) + return base + '.' + '.'.join(exts) + + @staticmethod + def depack_meta(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + ## using rawpy + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return (im - black_level) / (white_level - black_level), \ + wb, ccm + + @staticmethod + def depack_meta_gt(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + ## using rawpy + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return im, \ + wb, ccm + + + def randint(self, *range): + if self.engine == 'torch': + return torch.randint(*range, size=(1,)).item() + else: + return np.random.randint(*range) + + def flip(self, x, dim): + if self.engine == 'torch': + return torch.flip(x, (dim,)) + else: + return np.flip(x, dim) + + def transpose(self, x): + if self.engine == 'torch': + return torch.permute(x, (0, 2, 1)) + else: + return np.transpose(x, (0, 2, 1)) + + def __getitem__(self, index): + + def sum_img_and_noise(img, noises): + for noise in noises: + img += noise + return img + + + lq_path = self.lq_paths[index] + gt_path = self.gt_paths[index] + ratio = self.randint(*self.ratio_range) + + if self.postfix is not None: + lq_path = '.'.join([lq_path, self.postfix]) + gt_path = '.'.join([gt_path, self.postfix]) + + + if not self.load_in_mem: + lq_im, lq_wb, lq_ccm = self.depack_meta(lq_path, self.postfix) + gt_im, gt_wb, gt_ccm = self.depack_meta_gt(gt_path, self.postfix) + else: + lq_im, lq_wb, lq_ccm = self.lqs[lq_path] + gt_im, gt_wb, gt_ccm = self.gts[gt_path] + + + ## crop + if self.opt['crop_size'] is not None: + _, H, W = lq_im.shape + crop_size = self.opt['crop_size'] + assert crop_size <= H and crop_size <= W + if self.opt['phase'] == 'train': + h_start = torch.randint(0, H - crop_size, (1,)).item() + w_start = torch.randint(0, W - crop_size, (1,)).item() + else: + # center crop + h_start = (H - crop_size) // 2 + w_start = (W - crop_size) // 2 + lq_im_patch = lq_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + gt_im_patch = gt_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + else: + lq_im_patch = lq_im + gt_im_patch = gt_im + ## flip + rotate + if self.opt['phase'] == 'train': + hflip = self.opt['use_hflip'] and torch.rand((1,)).item() < 0.5 + vflip = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + rot90 = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + if hflip: + lq_im_patch = torch.flip(lq_im_patch, (2,)) + gt_im_patch = torch.flip(gt_im_patch, (2,)) + if vflip: + lq_im_patch = torch.flip(lq_im_patch, (1,)) + gt_im_patch = torch.flip(gt_im_patch, (1,)) + if rot90: + lq_im_patch = torch.permute(lq_im_patch, (0, 2, 1)) + gt_im_patch = torch.permute(gt_im_patch, (0, 2, 1)) + + if self.opt.get('ratio_aug') is not None: + ratio_range = self.opt['ratio_aug'] + rand_ratio = torch.rand((1,)).item() * (ratio_range[1] - ratio_range[0]) + ratio_range[0] + + gt_im_patch = gt_im_patch / ratio * rand_ratio + ratio = rand_ratio + + + + im_ratio = gt_im_patch.mean() // lq_im_patch.mean() + lq_im_patch = gt_im_patch / im_ratio + ratio_all = im_ratio * ratio + _, H, W = gt_im_patch.shape + + Highlight_Generator = EstimatorHighlightGenerator(ratio_all) + mask, addmap, lq_im_patch1 = Highlight_Generator.generate_highlight(lq_im_patch) + mask = mask.unsqueeze(0) + + lq_im_patch = lq_im_patch.to(gt_im_patch.device) + lq_im_patch = lq_im_patch * addmap + + map_ratio = (lq_im_patch * ratio_all ) / (gt_im_patch * ratio + 1e-8) + ratiomap_output = torch.mean(map_ratio, dim=0, keepdim=True).unsqueeze(0) * addmap + + lq_ev0_non_zero = lq_im_patch * im_ratio + exposures = [ + torch.clip(lq_ev0_non_zero, min=0, max=1), + torch.clip(lq_ev0_non_zero / 200, min=0, max=1), + torch.clip(lq_ev0_non_zero / 120, min=0, max=1), + torch.clip(lq_ev0_non_zero / 60, min=0, max=1), + torch.clip(lq_ev0_non_zero / 32, min=0, max=1), + torch.clip(lq_ev0_non_zero / 16, min=0, max=1), + torch.clip(lq_ev0_non_zero / 8, min=0, max=1), + torch.clip(lq_ev0_non_zero / 4, min=0, max=1), + torch.clip(lq_ev0_non_zero / 2, min=0, max=1), + ] + + exposures_gt = torch.stack(exposures) + lq_nonoise = lq_im_patch * im_ratio + gt_im_patch = torch.clip(gt_im_patch, min=0) + gt_im_patch = ratiomap_output.squeeze(0) + + + + + + + return { + 'lq_clean':lq_nonoise, + 'gt': exposures_gt, + 'ratio': torch.tensor(ratio), + 'ratio1': ratio_all, + 'wb': gt_wb if self.which_meta == 'gt' else lq_wb, + 'ccm': gt_ccm if self.which_meta == 'gt' else lq_ccm, + 'lq_path': lq_path, + 'gt_path': gt_path, + 'intact': True + } + + def __len__(self): + return len(self.lq_paths) + + + +@DATASET_REGISTRY.register() +class DenoiserHDRRAWDataset(data.Dataset): + def __init__(self, opt) -> None: + super().__init__() + self.opt = opt + self.engine = self.opt.get('engine', 'torch') + + self.root_folder = opt['dataroot'] + self.postfix = opt.get('postfix', None) + self.which_meta = opt.get('which_meta', 'gt') + self.zero_clip = 0 if opt.get('zero_clip', True) else None + self.ratio_range = opt.get('ratio_range', (100, 300)) + self.use_patches = opt.get('use_patches', False) + self.load_in_mem = opt.get('load_in_mem', False) + if self.use_patches: + self.patch_id_max = opt.get('patch_id_max', 8) + self.patch_tplt = opt.get('patch_tplt', '_s{:03}') + + assert self.postfix == 'npz' + + + self.lq_paths, self.gt_paths = [], [] + with open(opt['data_pair_list'], 'r') as data_pair_list: + pairs = data_pair_list.readlines() + for pair in pairs: + lq, gt = pair.split(' ')[:2] + gt = gt.rstrip('\n') + + if not self.use_patches: + self.lq_paths.append(osp.join(self.root_folder, lq)) + self.gt_paths.append(osp.join(self.root_folder, gt)) + + else: + for i in range(1, 1 + self.patch_id_max): + self.lq_paths.append(osp.join(self.root_folder, self.insert_patch_id(lq, i, self.patch_tplt))) + self.gt_paths.append(osp.join(self.root_folder, self.insert_patch_id(gt, i, self.patch_tplt))) + + if self.load_in_mem: + self.lqs = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.lq_paths, desc='load lq metas in mem...') + } + self.gts = { + '.'.join([data_path, self.postfix]): \ + self.depack_meta('.'.join([data_path, self.postfix]), self.postfix) + for data_path in tqdm(self.gt_paths, desc='load gt metas in mem...') + } + + @staticmethod + def insert_patch_id(path, patch_id, tplt='_s{:03}'): + exts = path.split('.') + base = exts.pop(0) + while exts[0] != 'ARW': + base += '.' + exts.pop(0) + base = base + tplt.format(patch_id) + return base + '.' + '.'.join(exts) + + @staticmethod + def depack_meta(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return (im - black_level) / (white_level - black_level), \ + wb, ccm + + @staticmethod + def depack_meta_gt(meta_path, postfix='npz', to_tensor=True): + if postfix == 'npz': + meta = np.load(meta_path, allow_pickle=True) + black_level = np.ascontiguousarray(meta['black_level'].copy().astype('float32')) + white_level = np.ascontiguousarray(meta['white_level'].copy().astype('float32')) + im = np.ascontiguousarray(meta['im'].copy().astype('float32')) + wb = np.ascontiguousarray(meta['wb'].copy().astype('float32')) + ccm = np.ascontiguousarray(meta['ccm'].copy().astype('float32')) + meta.close() + elif postfix == None: + ## using rawpy + raise NotImplementedError + else: + raise NotImplementedError + + if to_tensor: + im = torch.from_numpy(im).float().contiguous() + black_level = torch.from_numpy(black_level).float().contiguous() + white_level = torch.from_numpy(white_level).float().contiguous() + wb = torch.from_numpy(wb).float().contiguous() + ccm = torch.from_numpy(ccm).float().contiguous() + + return im, \ + wb, ccm + + + def randint(self, *range): + if self.engine == 'torch': + return torch.randint(*range, size=(1,)).item() + else: + return np.random.randint(*range) + + def flip(self, x, dim): + if self.engine == 'torch': + return torch.flip(x, (dim,)) + else: + return np.flip(x, dim) + + def transpose(self, x): + if self.engine == 'torch': + return torch.permute(x, (0, 2, 1)) + else: + return np.transpose(x, (0, 2, 1)) + + def __getitem__(self, index): + + def sum_img_and_noise(img, noises): + for noise in noises: + img += noise + return img + + + lq_path = self.lq_paths[index] + gt_path = self.gt_paths[index] + ratio = self.randint(*self.ratio_range) + + if self.postfix is not None: + lq_path = '.'.join([lq_path, self.postfix]) + gt_path = '.'.join([gt_path, self.postfix]) + + if not self.load_in_mem: + lq_im, lq_wb, lq_ccm = self.depack_meta(lq_path, self.postfix) + gt_im, gt_wb, gt_ccm = self.depack_meta_gt(gt_path, self.postfix) + else: + lq_im, lq_wb, lq_ccm = self.lqs[lq_path] + gt_im, gt_wb, gt_ccm = self.gts[gt_path] + + ## crop + if self.opt['crop_size'] is not None: + _, H, W = lq_im.shape + crop_size = self.opt['crop_size'] + assert crop_size <= H and crop_size <= W + if self.opt['phase'] == 'train': + h_start = torch.randint(0, H - crop_size, (1,)).item() + w_start = torch.randint(0, W - crop_size, (1,)).item() + else: + # center crop + h_start = (H - crop_size) // 2 + w_start = (W - crop_size) // 2 + lq_im_patch = lq_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + gt_im_patch = gt_im[:, h_start:h_start+crop_size, w_start:w_start+crop_size] + else: + lq_im_patch = lq_im + gt_im_patch = gt_im + ## flip + rotate + if self.opt['phase'] == 'train': + hflip = self.opt['use_hflip'] and torch.rand((1,)).item() < 0.5 + vflip = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + rot90 = self.opt['use_rot'] and torch.rand((1,)).item() < 0.5 + if hflip: + lq_im_patch = torch.flip(lq_im_patch, (2,)) + gt_im_patch = torch.flip(gt_im_patch, (2,)) + if vflip: + lq_im_patch = torch.flip(lq_im_patch, (1,)) + gt_im_patch = torch.flip(gt_im_patch, (1,)) + if rot90: + lq_im_patch = torch.permute(lq_im_patch, (0, 2, 1)) + gt_im_patch = torch.permute(gt_im_patch, (0, 2, 1)) + + if self.opt.get('ratio_aug') is not None: + ratio_range = self.opt['ratio_aug'] + rand_ratio = torch.rand((1,)).item() * (ratio_range[1] - ratio_range[0]) + ratio_range[0] + + gt_im_patch = gt_im_patch / ratio * rand_ratio + ratio = rand_ratio + + + + im_ratio = gt_im_patch.mean() // lq_im_patch.mean() + lq_im_patch = gt_im_patch / im_ratio + ratio_all = im_ratio * ratio + _, H, W = gt_im_patch.shape + + + Highlight_Generator = DenoiserHighlightGenerator(ratio_all) + mask, addmap, lq_im_patch1 = Highlight_Generator.generate_highlight(lq_im_patch) + mask = mask.unsqueeze(0) + + lq_im_patch = lq_im_patch.to(gt_im_patch.device) + lq_im_patch = lq_im_patch * addmap + + map_ratio = (lq_im_patch * ratio_all ) / (gt_im_patch * ratio + 1e-8) + ratiomap_output = torch.mean(map_ratio, dim=0, keepdim=True).unsqueeze(0) * addmap + + + + lq_ev0_non_zero = lq_im_patch * im_ratio + exposures = [ + torch.clip(lq_ev0_non_zero, min=0, max=1), + torch.clip(lq_ev0_non_zero / 200, min=0, max=1), + torch.clip(lq_ev0_non_zero / 120, min=0, max=1), + torch.clip(lq_ev0_non_zero / 60, min=0, max=1), + torch.clip(lq_ev0_non_zero / 32, min=0, max=1), + torch.clip(lq_ev0_non_zero / 16, min=0, max=1), + torch.clip(lq_ev0_non_zero / 8, min=0, max=1), + torch.clip(lq_ev0_non_zero / 4, min=0, max=1), + torch.clip(lq_ev0_non_zero / 2, min=0, max=1), + ] + + exposures_gt = torch.stack(exposures) + lq_nonoise = lq_im_patch * im_ratio + gt_im_patch = torch.clip(gt_im_patch, min=0) + gt_im_patch = ratiomap_output.squeeze(0) + + + + + + + return { + 'lq_clean':lq_nonoise, + 'gt': exposures_gt, + 'ratio': torch.tensor(ratio), + 'ratio1': ratio_all, + 'wb': gt_wb if self.which_meta == 'gt' else lq_wb, + 'ccm': gt_ccm if self.which_meta == 'gt' else lq_ccm, + 'lq_path': lq_path, + 'gt_path': gt_path, + 'intact': True + } + + def __len__(self): + return len(self.lq_paths) \ No newline at end of file diff --git a/ultraled/data/hdr_util.py b/ultraled/data/hdr_util.py new file mode 100644 index 0000000..72758ad --- /dev/null +++ b/ultraled/data/hdr_util.py @@ -0,0 +1,234 @@ +import math +import torch +from torch import nn +from torch.nn.functional import interpolate, pad, conv2d + +class ClipBase(nn.Module): + def __init__(self, clip=False) -> None: + super().__init__() + self._clip_func = torch.clamp if clip else diff_clamp + + def do_clip(self): + self._clip_func = torch.clamp + + def dont_clip(self): + self._clip_func = diff_clamp + + def forward(self, x): + return x + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def normalize_kernel2d(input): + norm = input.abs().sum(dim=-1).sum(dim=-1) + return input / (norm[..., None, None]) + + +def filter2d(input, kernel, border_type: str = 'reflect'): + # prepare kernel + c = input.shape[-3] + shape = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape = _compute_padding([height, width]) + if input.dim() == 5: + padding_shape += [0, 0] + input = pad(input, padding_shape, mode=border_type) + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + out = output.view(*shape) + return out + +def get_laplacian_kernel2d(kernel_size, *, device = None, dtype = torch.float32): + ky, kx = kernel_size + kernel = torch.ones((ky, kx), device=device, dtype=dtype) + mid_x = kx // 2 + mid_y = ky // 2 + kernel[mid_y, mid_x] = 1 - kernel.sum() + return kernel + +def laplacian(input, kernel_size, border_type: str = 'reflect', normalized: bool = True): + kernel = get_laplacian_kernel2d(kernel_size, device=input.device, dtype=input.dtype)[None, ...] + + if normalized: + kernel = normalize_kernel2d(kernel) + + return filter2d(input, kernel, border_type) + + +def rgb_to_grayscale(rgb): + r, g, b = rgb.unbind(dim=-3) + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(rgb.dtype) + l_img = l_img.unsqueeze(dim=-3) + return l_img + +def get_pyramid_gaussian_kernel(): + return ( + torch.tensor( + [ + [ + [1.0, 4.0, 6.0, 4.0, 1.0], + [4.0, 16.0, 24.0, 16.0, 4.0], + [6.0, 24.0, 36.0, 24.0, 6.0], + [4.0, 16.0, 24.0, 16.0, 4.0], + [1.0, 4.0, 6.0, 4.0, 1.0], + ] + ] + ) + / 256.0 + ) + +def pyrdown(input, border_type: str = 'reflect', align_corners: bool = False, factor: float = 2.0): + kernel = get_pyramid_gaussian_kernel() + channel, height, width = input.shape[-3:] + # blur image + x_blur = filter2d(input, kernel, border_type) + + shape = [int(float(height) / factor), int(float(width) // factor)] + mode = 'bilinear' + if input.dim() == 5: + mode = 'trilinear' + shape = [channel] + shape + + out = interpolate( + x_blur, + size=shape, + mode=mode, + align_corners=align_corners, + ) + return out + +def pyrup(input, shape, border_type: str = 'reflect', align_corners: bool = False): + kernel = get_pyramid_gaussian_kernel() + # upsample tensor + mode = 'bilinear' + if input.dim() == 5: + mode = 'trilinear' + else: + shape = shape[-2:] + x_up = interpolate(input, size=shape, mode=mode, align_corners=align_corners) + + # blurs upsampled tensor + x_blur = filter2d(x_up, kernel, border_type) + return x_blur + +def build_pyramid(input, max_level: int, border_type: str = 'reflect', align_corners: bool = False): + # create empty list and append the original image + pyramid = [] + pyramid.append(input) + + # iterate and downsample + for _ in range(max_level - 1): + img_curr = pyramid[-1] + img_down = pyrdown(img_curr, border_type, align_corners) + pyramid.append(img_down) + + return pyramid + +def build_laplacian_pyramid(input, max_level: int, border_type: str = 'reflect', align_corners: bool = False): + # create gaussian pyramid + gaussian_pyramid = build_pyramid(input, max_level, border_type, align_corners) + laplacian_pyramid = [] + + # iterate and compute difference of adjacent layers in a gaussian pyramid + for i in range(max_level - 1): + img_expand = pyrup(gaussian_pyramid[i + 1], gaussian_pyramid[i].shape[-3:], border_type, align_corners) + laplacian = gaussian_pyramid[i] - img_expand + laplacian_pyramid.append(laplacian) + laplacian_pyramid.append(gaussian_pyramid[-1]) + return laplacian_pyramid + +def diff_clamp(x, _min, _max, k=1e-3): + x = torch.minimum(x - _max, (x - _max) * k) + _max + x = torch.maximum(x - _min, (x - _min) * k) + _min + return x + +def pyramid_collapse(pyramid, depth): + for i in range(depth, 0, -1): + pyramid[i-1] += pyrup(pyramid[i], pyramid[i-1].shape[-3:]) + return pyramid[0] + + +class BlendMertens(nn.Module): + def __init__(self, contrast_weight=1.0, saturation_weight=1.0, exposure_weight=1.0, clip=False) -> None: + super().__init__() + self._clip_func = torch.clamp if clip else diff_clamp + self._contrast_weight = contrast_weight + self._saturation_weight = saturation_weight + self._exposure_weight = exposure_weight + + def do_clip(self): + self._clip_func = torch.clamp + + def dont_clip(self): + self._clip_func = diff_clamp + + def get_weight(self, x): + # contrast + gray_x = rgb_to_grayscale(x) + laplacian_x = laplacian(gray_x, (5, 5)) + c_weight = torch.abs(laplacian_x) + c_weight = c_weight ** self._contrast_weight + + # saturation + s_weight = torch.std(x, -3, keepdim=True) + s_weight = s_weight ** self._saturation_weight + + # exposure + sig = 0.2 + e_weight = torch.exp(-torch.pow(x - 0.5, 2) / (2 * sig * sig)) + r_w, g_w, b_w = torch.chunk(e_weight, 3, -3) + e_weight = r_w * g_w * b_w + e_weight = e_weight ** self._exposure_weight + + return c_weight * s_weight * e_weight + 1e-12 + + def forward(self, *data): + result = 0 + data = torch.stack(data) + weights = self.get_weight(data) + weight_sum = torch.sum(weights, 0, keepdim=True) + weights = weights / weight_sum + + pyramid_depth = min(int(math.log2(512)), int(math.log2(min(data[0].shape[-2:])))) + # pyramids + lps = build_laplacian_pyramid(data, pyramid_depth) + gps = build_pyramid(weights, pyramid_depth) + + # combine pyramids with weights + result_ps = [] + for i in range(pyramid_depth): + r_i = torch.sum(lps[i] * gps[i], 0) + result_ps.append(r_i) + result = self._clip_func(pyramid_collapse(result_ps, pyramid_depth-1), 0, 1) + return result \ No newline at end of file diff --git a/ultraled/data/noise_util_rawhdr.py b/ultraled/data/noise_util_rawhdr.py new file mode 100644 index 0000000..9db67d6 --- /dev/null +++ b/ultraled/data/noise_util_rawhdr.py @@ -0,0 +1,254 @@ +import torch +from torch import nn +import random +from ultraled.data.commom_noise_util import * + + + +class NoiseGenerator(nn.Module): + def __init__(self, camera_params, noise_type, *, engine='torch') -> None: + super().__init__() + self.camera_params = camera_params + self.cameras = list(camera_params.keys()) + print('Current Using Cameras: ', self.cameras) + + if engine == 'numpy': + self.engine = NumpyEngine() + else: + self.engine = TorchEngine() + + self.noise_type = noise_type.lower() + self.read_type = 'TurkeyLambda' if 't' in self.noise_type else \ + ('Gaussian' if 'g' in self.noise_type else None) + + @property + def sample_K(self): + index = self.engine.randint(0, len(self.camera_params)) + self.current_camera = self.cameras[index] + self.current_camera_params = self.camera_params[self.current_camera] + self.current_k_range = [ + self.camera_params[self.current_camera]['Kmin'], + self.camera_params[self.current_camera]['Kmax'] + ] + log_K_max = self.engine.log(self.current_camera_params['Kmax']) + log_K_min = self.engine.log(self.current_camera_params['Kmin']) + log_K = self.engine.uniform(log_K_min, log_K_max) + self.log_K = log_K + return self.engine.exp(log_K) + + + @property + def sample_read_param(self): + slope = self.current_camera_params[self.read_type]['slope'] + bias = self.current_camera_params[self.read_type]['bias'] + sigma = self.current_camera_params[self.read_type]['sigma'] + mu = self.log_K * slope + bias + sample = self.engine.randn() * sigma + mu + return self.engine.exp(sample) + + @property + def sample_turkey_lambda(self): + if self.read_type != 'TurkeyLambda': + return None + index = self.engine.randint(0, len(self.current_camera_params[self.read_type]['lambda'])) + return self.current_camera_params[self.read_type]['lambda'][index] + + @property + def sample_row_param(self): + slope = self.current_camera_params['Row']['slope'] + bias = self.current_camera_params['Row']['bias'] + sigma = self.current_camera_params['Row']['sigma'] + mu = self.log_K * slope + bias + sample = self.engine.randn() * sigma + mu + return self.engine.exp(sample) + + + + @property + def sample_color_bias(self): + count = len(self.current_camera_params['ColorBias']) + i_range = (self.current_k_range[1] - self.current_k_range[0]) / count + index = int((self.engine.exp(self.log_K) - self.current_k_range[0]) // i_range) + index = max(min(index, len(self.current_camera_params['ColorBias']) - 1), 0) + color_bias = self.current_camera_params['ColorBias'][index] + return self.engine.to_engine_type(color_bias).reshape(4, 1, 1) + + + + @torch.no_grad() + # def forward(self, img): + def forward(self, img, *, K=None): + + + if K is not None: + self.sample_K = K + else: + K = self.sample_K + # K = self.sample_K + + noise1 = [] + # possion noise + if 'p' in self.noise_type: + shot_noise = self.engine.shot_noise(img, K) + noise1.append(shot_noise) + # read noise + if 'g' in self.noise_type: + read_noise = self.engine.gaussian_noise(img, self.sample_read_param) + noise1.append(read_noise) + elif 't' in self.noise_type: + read_noise = self.engine.turkey_lambda_noise(img, self.sample_read_param, self.sample_turkey_lambda) + noise1.append(read_noise) + # row noise + if 'r' in self.noise_type: + row_noise = self.engine.row_noise(img, self.sample_row_param) + noise1.append(row_noise) + # quant noise + if 'q' in self.noise_type: + quant_noise = self.engine.quant_noise(img, 1) + noise1.append(quant_noise) + if 'c' in self.noise_type: + noise1.append(self.sample_color_bias.to(img.device)) + + + return img, noise1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + +### Support multiple cameras + +# class NoiseGenerator(nn.Module): +# def __init__(self, camera_params, noise_type, *, engine='torch') -> None: +# super().__init__() +# self.camera_params = camera_params +# self.cameras = list(camera_params.keys()) +# print('Current Using Cameras: ', self.cameras) + +# if engine == 'numpy': +# self.engine = NumpyEngine() +# else: +# self.engine = TorchEngine() + +# self.noise_type = noise_type.lower() +# self.read_type = 'TurkeyLambda' if 't' in self.noise_type else \ +# ('Gaussian' if 'g' in self.noise_type else None) + +# @property +# def sample_K(self): +# index = self.engine.randint(0, len(self.camera_params)) +# self.current_camera = self.cameras[index] +# self.current_camera_params = self.camera_params[self.current_camera] +# self.current_k_range = [ +# self.camera_params[self.current_camera]['Kmin'], +# self.camera_params[self.current_camera]['Kmax'] +# ] +# log_K_max = self.engine.log(self.current_camera_params['Kmax']) +# log_K_min = self.engine.log(self.current_camera_params['Kmin']) +# log_K = self.engine.uniform(log_K_min, log_K_max) +# self.log_K = log_K +# return self.engine.exp(log_K) + +# @sample_K.setter +# def sample_K(self, K): +# assert len(self.camera_params) == 1 +# index = 0 +# self.current_camera = self.cameras[index] +# self.current_camera_params = self.camera_params[self.current_camera] +# self.current_k_range = [ +# self.camera_params[self.current_camera]['Kmin'], +# self.camera_params[self.current_camera]['Kmax'] +# ] +# self.log_K = self.engine.log(K) + +# @property +# def sample_read_param(self): +# slope = self.current_camera_params[self.read_type]['slope'] +# bias = self.current_camera_params[self.read_type]['bias'] +# sigma = self.current_camera_params[self.read_type]['sigma'] +# mu = self.log_K * slope + bias +# sample = self.engine.randn() * sigma + mu +# return self.engine.exp(sample) + +# @property +# def sample_turkey_lambda(self): +# if self.read_type != 'TurkeyLambda': +# return None +# index = self.engine.randint(0, len(self.current_camera_params[self.read_type]['lambda'])) +# return self.current_camera_params[self.read_type]['lambda'][index] + +# @property +# def sample_row_param(self): +# slope = self.current_camera_params['Row']['slope'] +# bias = self.current_camera_params['Row']['bias'] +# sigma = self.current_camera_params['Row']['sigma'] +# mu = self.log_K * slope + bias +# sample = self.engine.randn() * sigma + mu +# return self.engine.exp(sample) + +# @property +# def sample_color_bias(self): +# count = len(self.current_camera_params['ColorBias']) +# i_range = (self.current_k_range[1] - self.current_k_range[0]) / count +# index = int((self.engine.exp(self.log_K) - self.current_k_range[0]) // i_range) +# index = max(min(index, len(self.current_camera_params['ColorBias']) - 1), 0) +# color_bias = self.current_camera_params['ColorBias'][index] +# return self.engine.to_engine_type(color_bias).reshape(4, 1, 1) + +# @torch.no_grad() +# # def forward(self, img): +# def forward(self, img, *, K=None): +# if K is not None: +# self.sample_K = K +# else: +# K = self.sample_K +# # K = self.sample_K + +# noise1 = [] +# # possion noise +# if 'p' in self.noise_type: +# shot_noise = self.engine.shot_noise(img, K) +# noise1.append(shot_noise) +# # read noise +# if 'g' in self.noise_type: +# read_noise = self.engine.gaussian_noise(img, self.sample_read_param) +# noise1.append(read_noise) +# elif 't' in self.noise_type: +# read_noise = self.engine.turkey_lambda_noise(img, self.sample_read_param, self.sample_turkey_lambda) +# noise1.append(read_noise) +# # row noise +# if 'r' in self.noise_type: +# row_noise = self.engine.row_noise(img, self.sample_row_param) +# noise1.append(row_noise) +# # quant noise +# if 'q' in self.noise_type: +# quant_noise = self.engine.quant_noise(img, 1) +# noise1.append(quant_noise) +# if 'c' in self.noise_type: +# noise1.append(self.sample_color_bias) + +# return img, noise1 + diff --git a/ultraled/data/part_enhance.py b/ultraled/data/part_enhance.py new file mode 100644 index 0000000..77ec7cd --- /dev/null +++ b/ultraled/data/part_enhance.py @@ -0,0 +1,206 @@ +import math +import torch +import numpy as np +import random +from scipy.ndimage import binary_erosion, binary_dilation + +class EstimatorHighlightGenerator: + def __init__(self, max_gain=10.0): + self.max_gain = max_gain + + def gaussian_kernel(self, distance, radius, max_gain, boundary_gain): + return max_gain * torch.exp(-(distance ** 2) / (2 * (radius ** 2))) + + def inverse_square_exponential_kernel(self, distance, radius, max_gain, boundary_gain): + epsilon = 1.0 + beta = 1.0 + return (max_gain / ((distance / radius) ** 2 + epsilon)) * torch.exp(-beta * (distance / radius)) + + def random_kernel(self, distance, radius, max_gain, boundary_gain): + kernels = [self.gaussian_kernel, self.inverse_square_exponential_kernel] + kernel = random.choice(kernels) + kernel_mask = kernel(distance, radius, max_gain, boundary_gain) + kernel_mask[kernel_mask < boundary_gain] = boundary_gain + return kernel_mask + + def torch_line(self, mask, x1, y1, x2, y2): + dx = abs(x2 - x1) + dy = abs(y2 - y1) + sx = 1 if x1 < x2 else -1 + sy = 1 if y1 < y2 else -1 + err = dx - dy + + while True: + if 0 <= x1 < mask.shape[1] and 0 <= y1 < mask.shape[0]: + mask[y1, x1] = 1 + if x1 == x2 and y1 == y2: + break + e2 = err * 2 + if e2 > -dy: + err -= dy + x1 += sx + if e2 < dx: + err += dx + y1 += sy + + def generate_highlight(self, tensor): + tensor = tensor.unsqueeze(0) + H, W = tensor.shape[2], tensor.shape[3] + mask = torch.ones((H, W), dtype=torch.uint8) + addmap = torch.zeros((H, W)) + final_image = torch.clone(tensor) + + total_area = 0.0 + max_area = (H * W) / 5.0 + num_regions = torch.randint(1, 11, (1,)).item() + + for _ in range(num_regions): + region_type = torch.randint(0, 2, (1,)).item() + center_x = torch.randint(0, W, (1,)).item() + center_y = torch.randint(0, H, (1,)).item() + + if region_type == 0: + radius = torch.randint(10, 50, (1,)).item() + area = math.pi * (radius ** 2) + else: + num_sides = torch.randint(5, 9, (1,)).item() + radius = torch.randint(10, 50, (1,)).item() + area = num_sides * (radius ** 2) * 0.5 + + if total_area + area > max_area: + continue + + region_gain = torch.rand(1).item() * self.max_gain + region_gain_matrix = torch.ones((H, W)) + + if region_type == 0: + y_grid, x_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + distance = ((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2).sqrt() + region_mask = distance <= radius + boundary_gain = min(random.randint(1, 1), self.max_gain) + region_gain_matrix[region_mask] = self.random_kernel( + distance[region_mask], radius, region_gain, boundary_gain + ) + else: + num_sides = max(3, num_sides) + angles = torch.linspace(0, 2 * math.pi, num_sides + 1) + x_offsets = (torch.cos(angles) * radius).int() + y_offsets = (torch.sin(angles) * radius).int() + polygon_points = [(center_x + x_offsets[i], center_y + y_offsets[i]) for i in range(num_sides)] + region_mask = torch.zeros((H, W), dtype=torch.bool) + + for i in range(num_sides): + x1, y1 = polygon_points[i] + x2, y2 = polygon_points[(i + 1) % num_sides] + self.torch_line(region_mask, x1, y1, x2, y2) + + y_grid, x_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + distance_from_center = ((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2).sqrt() + boundary_gain = min(random.randint(1, 1), self.max_gain) + region_gain_matrix[region_mask] = self.random_kernel( + distance_from_center[region_mask], radius, region_gain, boundary_gain + ) + + mask[region_mask] = 0 + final_image += region_gain_matrix.unsqueeze(0).unsqueeze(0) * region_mask.float() + region_gain_matrix1 = torch.ones_like(region_gain_matrix) + region_gain_matrix1[region_mask] = region_gain_matrix[region_mask] + addmap = region_gain_matrix1.float() + total_area += area + + final_image = torch.clamp(final_image, 0, 1) + return mask, addmap, final_image + + +class DenoiserHighlightGenerator: + def __init__(self, max_gain=10.0): + self.max_gain = max_gain + + def gaussian_kernel(self, distance, radius, max_gain, boundary_gain): + alpha = random.uniform(0.1, 1.0) + return max_gain * torch.exp(-(distance ** 2) / (alpha * (radius ** 2))) + + def inverse_square_exponential_kernel(self, distance, radius, max_gain, boundary_gain): + epsilon = 1.0 + beta = random.uniform(1.0, math.sqrt(max_gain)) + return (max_gain / ((((distance / radius) ** 2) * beta) + epsilon)) + + def random_kernel(self, distance, radius, max_gain, boundary_gain): + kernels = [self.gaussian_kernel, self.inverse_square_exponential_kernel] + kernel = random.choice(kernels) + kernel_mask = kernel(distance, radius, max_gain, boundary_gain) + kernel_mask[kernel_mask < boundary_gain] = boundary_gain + return kernel_mask + + def torch_line_fill(self, mask, x1, y1, x2, y2): + H, W = mask.shape + x1, x2, y1, y2 = min(x1, x2), min(max(x1, x2), W-1), min(y1, y2), min(max(y1, y2), H-1) + a = random.randint(1, x2-x1+2) + b, c = random.uniform(-1, 1), random.uniform(-1, 1) + + for i in range(x1, x2): + ymin = torch.tensor(min(np.floor(y1 + b * y1 * np.sin(i/a*np.pi)), H-2)).int() + ymax = torch.tensor(min(np.floor(y2 + c * y2 * np.cos(i/a*np.pi)), H-2)).int() + mask[ymin:ymax, i] = 1 + + def generate_highlight(self, tensor): + tensor = tensor.unsqueeze(0) + H, W = tensor.shape[2], tensor.shape[3] + mask = torch.ones((H, W), dtype=torch.uint8) + addmap = torch.zeros((H, W)) + final_image = torch.clone(tensor) + + total_area = 0.0 + max_area = (H * W) / 2.0 + num_regions = torch.randint(1, 11, (1,)).item() + + region_gain_matrix = torch.ones((H, W)) + region_gain_matrix1 = torch.ones_like(region_gain_matrix) + + for _ in range(num_regions): + center_x = torch.randint(0, W, (1,)).item() + center_y = torch.randint(0, H, (1,)).item() + radius = torch.randint(10, 300, (1,)).item() + area = math.pi * (radius ** 2) + + if total_area + area > max_area: + continue + + region_gain = torch.rand(1).item() * self.max_gain + + num_sides = torch.randint(5, 9, (1,)).item() + num_sides = max(3, num_sides) + + x1, x2 = random.randint(0, W), random.randint(0, W) + y1, y2 = random.randint(0, H), random.randint(0, H) + center_x, center_y = random.randint(min(x1, x2), max(x1, x2)), random.randint(min(y1, y2), max(y1, y2)) + + region_mask = torch.zeros((H, W), dtype=torch.bool) + self.torch_line_fill(region_mask, x1, y1, x2, y2) + + region_mask_np = region_mask.detach().cpu().numpy() + for _ in range(random.randint(0, 5)): + if random.choice([True, False]): + region_mask_np = binary_erosion(region_mask_np) + else: + region_mask_np = binary_dilation(region_mask_np) + region_mask = torch.from_numpy(region_mask_np).to(region_mask.device) + + y_grid, x_grid = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") + distance_from_center = ((x_grid - center_x) ** 2 + (y_grid - center_y) ** 2).sqrt() + + boundary_gain = min(random.randint(1, 1), self.max_gain) + region_gain_matrix[region_mask] = self.random_kernel( + distance_from_center[region_mask], radius, region_gain, boundary_gain + ) + + mask[region_mask] = 0 + final_image += region_gain_matrix.unsqueeze(0).unsqueeze(0) * region_mask.float() + + region_gain_matrix1[region_mask] = region_gain_matrix[region_mask] + addmap = region_gain_matrix1.float() + total_area += area + + final_image = torch.clamp(final_image, 0, 1) + return mask, addmap, final_image + diff --git a/ultraled/data/prefetch_dataloader.py b/ultraled/data/prefetch_dataloader.py new file mode 100644 index 0000000..5088425 --- /dev/null +++ b/ultraled/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/ultraled/data/raw_utils.py b/ultraled/data/raw_utils.py new file mode 100644 index 0000000..745c3f5 --- /dev/null +++ b/ultraled/data/raw_utils.py @@ -0,0 +1,191 @@ +import numpy as np +import os +import exifread +import argparse +import glob +import time +from copy import deepcopy +import math + + +from torch import nn +import pyiqa +import re + +import cv2 +import rawpy +import torch +import torch.nn.functional as F +from tqdm import tqdm + + + +def read_img(raw_path): + """Read and process raw image.""" + raw = rawpy.imread(raw_path) + raw_vis = raw.raw_image_visible.copy() + raw_pattern = raw.raw_pattern + + # Process black and white levels + black_level = np.array(raw.black_level_per_channel, dtype=np.float32).reshape(1, 4, 1, 1) + white_level = np.array(raw.camera_white_level_per_channel, dtype=np.float32) + + if (white_level == None).any(): + white_level = np.array(raw.white_level, dtype=np.float32) + if white_level.size == 1: + white_level = white_level.repeat(4, 0) + + white_level = white_level.reshape(1, 4, 1, 1) + raw_packed = torch.from_numpy(np.float32(pack_raw_bayer(raw_vis, raw_pattern))[np.newaxis]).contiguous() + black_level = torch.from_numpy(black_level).contiguous() + white_level = torch.from_numpy(white_level).contiguous() + + return raw, raw_pattern, raw_packed, black_level, white_level + + +def postprocess(raw, raw_pattern, im, bl, wl, output_bps = 8): + """Post-process the image to RGB.""" + im = im * (wl - bl) + bl + im = im.numpy()[0] + im = depack_raw_bayer(im, raw_pattern) + + H, W = im.shape + raw.raw_image_visible[:H, :W] = im + rgb = raw.postprocess(use_camera_wb=True, half_size=False, + no_auto_bright=True, output_bps=output_bps) + rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + return rgb + + +def filter_bilateral(tenIn, intSize, tenSigmas, tenSigmac): + """Bilateral filter implementation.""" + tenSigmas = tenSigmas.view(-1, 1, 1, 1, 1) + tenSigmac = tenSigmac.view(-1, 1, 1, 1, 1) + + # Create coordinate grids + half_size = int(math.floor(0.5 * intSize)) + coords = torch.linspace(-half_size, half_size, intSize, + dtype=tenIn.dtype, device=tenIn.device) + tenHor = coords.view(1, -1).repeat(intSize, 1) + tenVer = coords.view(-1, 1).repeat(1, intSize) + + # Calculate distances + tenDists = (tenHor.square() + tenVer.square()).sqrt().view(1, 1, intSize * intSize, 1, 1) + tenDistc = tenIn.view(tenIn.shape[0], tenIn.shape[1], 1, tenIn.shape[2], tenIn.shape[3]) + + # Apply bilateral filtering + tenOut = F.pad(input=tenIn, pad=[half_size, half_size, half_size, half_size], mode='reflect') + tenOut = F.unfold(input=tenOut, kernel_size=intSize, stride=1, padding=0) + tenOut = tenOut.view(tenIn.shape[0], tenIn.shape[1], intSize * intSize, tenIn.shape[2], tenIn.shape[3]) + + tenWeight = ((-0.5 * tenDists.square() / (tenSigmas.square() + 1e-8)) + + (-0.5 * (tenOut - tenDistc).mean([1], True).square() / (tenSigmac.square() + 1e-8))).exp() + tenWeight = tenWeight / (tenWeight.sum([2], True) + 1e-8) + tenOut = (tenOut * tenWeight).sum([2], False) + + return tenOut + + + +Sony_A7S2_CCM = np.array([[ 1.9712269,-0.6789218, -0.29230508], + [-0.29104823, 1.748401 , -0.45735288], + [ 0.02051281,-0.5380369, 1.5175241 ]], + dtype='float32') + + +def pack_raw_bayer(raw: np.ndarray, raw_pattern: np.ndarray): + #pack Bayer image to 4 channels + R = np.where(raw_pattern==0) + G1 = np.where(raw_pattern==1) + B = np.where(raw_pattern==2) + G2 = np.where(raw_pattern==3) + + raw = raw.astype(np.uint16) + H, W = raw.shape + if H % 2 == 1: + raw = raw[:-1] + if W % 2 == 1: + raw = raw[:, :-1] + out = np.stack((raw[R[0][0]::2, R[1][0]::2], #RGBG + raw[G1[0][0]::2, G1[1][0]::2], + raw[B[0][0]::2, B[1][0]::2], + raw[G2[0][0]::2, G2[1][0]::2]), axis=0).astype(np.uint16) + + return out + + +def depack_raw_bayer(raw: np.ndarray, raw_pattern: np.ndarray): + _, H, W = raw.shape + # raw = raw.astype(np.uint16) + raw = raw.astype(np.float64) + + R = np.where(raw_pattern==0) + G1 = np.where(raw_pattern==1) + B = np.where(raw_pattern==2) + G2 = np.where(raw_pattern==3) + + raw_flatten = np.zeros((H * 2, W * 2)) + raw_flatten[R[0][0]::2, R[1][0]::2] = raw[0] + raw_flatten[G1[0][0]::2, G1[1][0]::2] = raw[1] + raw_flatten[B[0][0]::2, B[1][0]::2] = raw[2] + raw_flatten[G2[0][0]::2, G2[1][0]::2] = raw[3] + + # raw_flatten = raw_flatten.astype(np.uint16) + raw_flatten = raw_flatten.astype(np.float64) + + return raw_flatten + + +def metainfo(rawpath): + with open(rawpath, 'rb') as f: + tags = exifread.process_file(f) + _, suffix = os.path.splitext(os.path.basename(rawpath)) + + if suffix == '.dng': + expo = eval(str(tags['Image ExposureTime'])) + iso = eval(str(tags['Image ISOSpeedRatings'])) + else: + expo = eval(str(tags['EXIF ExposureTime'])) + iso = eval(str(tags['EXIF ISOSpeedRatings'])) + + # print('ISO: {}, ExposureTime: {}'.format(iso, expo)) + return iso, expo + + + +def illuminance_correct(x, y): + x_m = x.mean(dim=(-1, -2)) + y_m = y.mean(dim=(-1, -2)) + xy_m = (x * y).mean(dim=(-1, -2)) + xx_m = (x * x).mean(dim=(-1, -2)) + a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m) + b = y_m - a * x_m + return a.reshape(1, -1, 1, 1) * x + b.reshape(1, -1, 1, 1) + +def resize_image(img, target_shape, is_mask=False): + h, w = target_shape + interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_AREA + return cv2.resize(img, (w, h), interpolation=interpolation) + +def load_image(path, target_size=None): + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img is None: + raise FileNotFoundError(f"Cannot read image {path}") + + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + + if img.shape[2] == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + if target_size is not None: + img = cv2.resize(img, (target_size[1], target_size[0]), interpolation=cv2.INTER_AREA) + + return img.astype(np.float32) + +def image_to_tensor(img): + return torch.from_numpy(img).permute(2,0,1).unsqueeze(0) + + + diff --git a/ultraled/losses/__init__.py b/ultraled/losses/__init__.py new file mode 100644 index 0000000..45fcadb --- /dev/null +++ b/ultraled/losses/__init__.py @@ -0,0 +1,31 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.registry import LOSS_REGISTRY +from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty + +__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] + +# automatically scan and import loss modules for registry +# scan all the files under the 'losses' folder and collect files ending with '_loss.py' +loss_folder = osp.dirname(osp.abspath(__file__)) +loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] +# import all the loss modules +_model_modules = [importlib.import_module(f'ultraled.losses.{file_name}') for file_name in loss_filenames] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/ultraled/losses/__pycache__/__init__.cpython-38.pyc b/ultraled/losses/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..516a2de Binary files /dev/null and b/ultraled/losses/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/losses/__pycache__/basic_loss.cpython-38.pyc b/ultraled/losses/__pycache__/basic_loss.cpython-38.pyc new file mode 100644 index 0000000..55c1d17 Binary files /dev/null and b/ultraled/losses/__pycache__/basic_loss.cpython-38.pyc differ diff --git a/ultraled/losses/__pycache__/gan_loss.cpython-38.pyc b/ultraled/losses/__pycache__/gan_loss.cpython-38.pyc new file mode 100644 index 0000000..5a0560c Binary files /dev/null and b/ultraled/losses/__pycache__/gan_loss.cpython-38.pyc differ diff --git a/ultraled/losses/__pycache__/loss_util.cpython-38.pyc b/ultraled/losses/__pycache__/loss_util.cpython-38.pyc new file mode 100644 index 0000000..aeb4bdd Binary files /dev/null and b/ultraled/losses/__pycache__/loss_util.cpython-38.pyc differ diff --git a/ultraled/losses/archive/conditional_loss.py b/ultraled/losses/archive/conditional_loss.py new file mode 100644 index 0000000..dd5cb27 --- /dev/null +++ b/ultraled/losses/archive/conditional_loss.py @@ -0,0 +1,55 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from ultraled.utils.registry import LOSS_REGISTRY + +@LOSS_REGISTRY.register() +class EMAConditionalLoss(nn.Module): + def __init__(self, branch_num, ema_decay=0.999, ema_time=1, core='l1', *, conv_count=1, reduce='mean', loss_weight=1.0, init='eye') -> None: + super().__init__() + if init == 'eye': + gt = torch.eye(branch_num) + elif init == 'rand': + gt = torch.rand((branch_num, branch_num)) + gt = torch.softmax(gt, 1) + elif init == 'randn': + gt = torch.randn((branch_num, branch_num)) + gt = torch.softmax(gt, 1) + # gt = self.label_smoothing(torch.eye(branch_num)) + gt = gt.repeat(1, conv_count) + self.gts = nn.parameter.Parameter(gt, requires_grad=False) + self.ema_decay = ema_decay + self.ema_time = ema_time + self.counter = nn.parameter.Parameter(torch.zeros(branch_num, dtype=torch.int), requires_grad=False) + + assert reduce in ['mean', 'sum'] + reduce_fn = eval(f'torch.{reduce}') + if core == 'l1': + self.loss = lambda x, y: reduce_fn(torch.abs(x - y)) + elif core == 'l2' or core == 'mse': + self.loss = lambda x, y: reduce_fn((x - y) * (x - y)) + elif core == 'cross_entropy': + self.loss = lambda x, y: F.cross_entropy(x, y, reduce=reduce) + else: + raise NotImplementedError() + + self.loss_weight = loss_weight + + def label_smoothing(self, x): + x[x == 1] = 0.8 + x[x == 0] = 0.05 + return x + + def ema(self, x, index): + self.counter[index] += 1 + if self.counter[index] % self.ema_time == 0: + self.counter[index] = 0 + self.gts[index].data.mul_(self.ema_decay).add_(x, alpha=1-self.ema_decay) + print(self.gts) + + def forward(self, x, index): + loss = self.loss(x, self.gts[index]) + if self.ema_decay > 0: + self.ema(x, index) + return loss * self.loss_weight diff --git a/ultraled/losses/basic_loss.py b/ultraled/losses/basic_loss.py new file mode 100644 index 0000000..c0a0970 --- /dev/null +++ b/ultraled/losses/basic_loss.py @@ -0,0 +1,380 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +# from basicsr.archs.vgg_arch import VGGFeatureExtractor +from ultraled.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + +@weighted_loss +def raw_loss_sqrt(pred, target): + return F.l1_loss(pred, target, reduction='none') / torch.sqrt(target + 1.0e-8) + +@weighted_loss +def raw_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') / (target + 1.0e-8) + + +@LOSS_REGISTRY.register() +class TVLoss(nn.Module): + """Structure-Aware Total Variation Loss. + + Args: + loss_weight (float): Loss weight for TV loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + alpha (float): Exponent for gradient weighting. Default: 1.2. + lamda (float): Scaling factor for gradient weighting. Default: 1.5. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', alpha=1.2, lamda=1.5): + super(TVLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.alpha = alpha + self.lamda = lamda + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, 3, H, W). Ground truth tensor (RGB image). + weight (Tensor, optional): of shape (N, 1, H, W). Element-wise weights. Default: None. + """ + + weights = torch.tensor([0.2989, 0.2935, 0.1140, 0.2935], device=pred.device).view(1, 4, 1, 1) + I = torch.sum(pred * weights, dim=1, keepdim=True) # (N, 1, H, W) + L = torch.log(I + 0.0001) + dx = L[:, :, :-1, :-1] - L[:, :, :-1, 1:] # (N, 1, H-1, W-1) + dy = L[:, :, :-1, :-1] - L[:, :, 1:, :-1] # (N, 1, H-1, W-1) + + + dx_weight = self.lamda / (torch.abs(dx).pow(self.alpha) + 0.0001) + dy_weight = self.lamda / (torch.abs(dy).pow(self.alpha) + 0.0001) + x_diff = target[:, :, :-1, :-1] - target[:, :, :-1, 1:] # (N, C, H-1, W-1) + y_diff = target[:, :, :-1, :-1] - target[:, :, 1:, :-1] # (N, C, H-1, W-1) + print(x_diff.shape, dx_weight.shape, dx.shape) + x_loss = dx_weight * (x_diff ** 2) + y_loss = dy_weight * (y_diff ** 2) + loss = (x_loss + y_loss) / 2.0 + + # Apply reduction + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() + elif self.reduction == 'none': + pass # Keep the original shape + + # Apply weight if provided + if weight is not None: + loss = loss * weight + + return self.loss_weight * loss + + +@LOSS_REGISTRY.register() +class RAWSQRTL1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(RAWSQRTL1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + + return self.loss_weight * raw_loss_sqrt(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class RAWL1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(RAWL1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + + return self.loss_weight * raw_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + if reduction not in ['mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) + + def forward(self, pred, weight=None): + if weight is None: + y_weight = None + x_weight = None + else: + y_weight = weight[:, :, :-1, :] + x_weight = weight[:, :, :, :-1] + + y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) + x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram diff --git a/ultraled/losses/gan_loss.py b/ultraled/losses/gan_loss.py new file mode 100644 index 0000000..d1afc90 --- /dev/null +++ b/ultraled/losses/gan_loss.py @@ -0,0 +1,208 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from ultraled.utils.registry import LOSS_REGISTRY + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/ultraled/losses/loss_util.py b/ultraled/losses/loss_util.py new file mode 100644 index 0000000..fd293ff --- /dev/null +++ b/ultraled/losses/loss_util.py @@ -0,0 +1,145 @@ +import functools +import torch +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper + + +def get_local_weights(residual, ksize): + """Get local weights for generating the artifact map of LDL. + + It is only called by the `get_refined_artifact_map` function. + + Args: + residual (Tensor): Residual between predicted and ground truth images. + ksize (Int): size of the local window. + + Returns: + Tensor: weight for each pixel to be discriminated as an artifact pixel + """ + + pad = (ksize - 1) // 2 + residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') + + unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) + pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) + + return pixel_level_weight + + +def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): + """Calculate the artifact map of LDL + (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) + + Args: + img_gt (Tensor): ground truth images. + img_output (Tensor): output images given by the optimizing model. + img_ema (Tensor): output images given by the ema model. + ksize (Int): size of the local window. + + Returns: + overall_weight: weight for each pixel to be discriminated as an artifact pixel + (calculated based on both local and global observations). + """ + + residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) + residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) + + patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) + pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) + overall_weight = patch_level_weight * pixel_level_weight + + overall_weight[residual_sr < residual_ema] = 0 + + return overall_weight diff --git a/ultraled/metrics/__init__.py b/ultraled/metrics/__init__.py new file mode 100644 index 0000000..8c50e94 --- /dev/null +++ b/ultraled/metrics/__init__.py @@ -0,0 +1,20 @@ +from copy import deepcopy + +from ultraled.utils.registry import METRIC_REGISTRY +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/ultraled/metrics/__pycache__/__init__.cpython-38.pyc b/ultraled/metrics/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..36b49aa Binary files /dev/null and b/ultraled/metrics/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/metrics/__pycache__/metric_util.cpython-38.pyc b/ultraled/metrics/__pycache__/metric_util.cpython-38.pyc new file mode 100644 index 0000000..e5ebb80 Binary files /dev/null and b/ultraled/metrics/__pycache__/metric_util.cpython-38.pyc differ diff --git a/ultraled/metrics/__pycache__/niqe.cpython-38.pyc b/ultraled/metrics/__pycache__/niqe.cpython-38.pyc new file mode 100644 index 0000000..4f3242e Binary files /dev/null and b/ultraled/metrics/__pycache__/niqe.cpython-38.pyc differ diff --git a/ultraled/metrics/__pycache__/psnr_ssim.cpython-38.pyc b/ultraled/metrics/__pycache__/psnr_ssim.cpython-38.pyc new file mode 100644 index 0000000..123f8fc Binary files /dev/null and b/ultraled/metrics/__pycache__/psnr_ssim.cpython-38.pyc differ diff --git a/ultraled/metrics/fid.py b/ultraled/metrics/fid.py new file mode 100644 index 0000000..2b926eb --- /dev/null +++ b/ultraled/metrics/fid.py @@ -0,0 +1,93 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from basicsr.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. + """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for + generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an + representative data set. + sigma2 (np.array): The covariance matrix over activations, + precalculated on an representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/ultraled/metrics/metric_util.py b/ultraled/metrics/metric_util.py new file mode 100644 index 0000000..52a84c2 --- /dev/null +++ b/ultraled/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from ultraled.utils import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/ultraled/metrics/niqe.py b/ultraled/metrics/niqe.py new file mode 100644 index 0000000..07616de --- /dev/null +++ b/ultraled/metrics/niqe.py @@ -0,0 +1,197 @@ +import cv2 +import math +import numpy as np +import os +from scipy.ndimage.filters import convolve +from scipy.special import gamma + +from ultraled.metrics.metric_util import reorder_image, to_y_channel +from ultraled.utils.matlab_functions import imresize +from ultraled.utils.registry import METRIC_REGISTRY + + +def estimate_aggd_param(block): + """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters. + + Args: + block (ndarray): 2D Image block. + + Returns: + tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD + distribution (Estimating the parames in Equation 7 in the paper). + """ + block = block.flatten() + gam = np.arange(0.2, 10.001, 0.001) # len = 9801 + gam_reciprocal = np.reciprocal(gam) + r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + + left_std = np.sqrt(np.mean(block[block < 0]**2)) + right_std = np.sqrt(np.mean(block[block > 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + Note that we do not include block overlap height and width, since they are + always 0 in the official implementation. + + For good performance, it is advisable by the official implementation to + divide the distorted image in to the same size patched as used for the + construction of multivariate Gaussian model. + + Args: + img (ndarray): Input image whose quality needs to be computed. The + image must be a gray or Y (of YCbCr) image with shape (h, w). + Range [0, 255] with float type. + mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian + model calculated on the pristine dataset. + cov_pris_param (ndarray): Covariance of a pre-defined multivariate + Gaussian model calculated on the pristine dataset. + gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the + image. + block_size_h (int): Height of the blocks in to which image is divided. + Default: 96 (the official recommended value). + block_size_w (int): Width of the blocks in to which image is divided. + Default: 96 (the official recommended value). + """ + assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).') + # crop image + h, w = img.shape + num_block_h = math.floor(h / block_size_h) + num_block_w = math.floor(w / block_size_w) + img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w] + + distparam = [] # dist param is actually the multiscale features + for scale in (1, 2): # perform on two scales (1, 2) + mu = convolve(img, gaussian_window, mode='nearest') + sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu))) + # normalize, as in Eq. 1 in the paper + img_nomalized = (img - mu) / (sigma + 1) + + feat = [] + for idx_w in range(num_block_w): + for idx_h in range(num_block_h): + # process ecah block + block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale, + idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale] + feat.append(compute_feature(block)) + + distparam.append(np.array(feat)) + + if scale == 1: + img = imresize(img / 255., scale=0.5, antialiasing=True) + img = img * 255. + + distparam = np.concatenate(distparam, axis=1) + + # fit a MVG (multivariate Gaussian) model to distorted patch features + mu_distparam = np.nanmean(distparam, axis=0) + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) + + # compute niqe quality, Eq. 10 in the paper + invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) + quality = np.matmul( + np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam))) + + quality = np.sqrt(quality) + quality = float(np.squeeze(quality)) + return quality + + +@METRIC_REGISTRY.register() +def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. + This implementation could produce almost the same results as the official + MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + + > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296) + > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296) + + We use the official params estimated from the pristine dataset. + We use the recommended block size (96, 96) without overlaps. + + Args: + img (ndarray): Input image whose quality needs to be computed. + The input image must be in range [0, 255] with float/int type. + The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order) + If the input order is 'HWC' or 'CHW', it will be converted to gray + or Y (of YCbCr) image according to the ``convert_to`` argument. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the metric calculation. + input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'. + Default: 'HWC'. + convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'. + Default: 'y'. + + Returns: + float: NIQE result. + """ + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + # we use the official params estimated from the pristine dataset. + niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz')) + mu_pris_param = niqe_pris_params['mu_pris_param'] + cov_pris_param = niqe_pris_params['cov_pris_param'] + gaussian_window = niqe_pris_params['gaussian_window'] + + img = img.astype(np.float32) + if input_order != 'HW': + img = reorder_image(img, input_order=input_order) + if convert_to == 'y': + img = to_y_channel(img) + elif convert_to == 'gray': + img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255. + img = np.squeeze(img) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border] + + # round is necessary for being consistent with MATLAB's result + img = img.round() + + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) + + return niqe_result diff --git a/ultraled/metrics/psnr_ssim.py b/ultraled/metrics/psnr_ssim.py new file mode 100644 index 0000000..439ad7d --- /dev/null +++ b/ultraled/metrics/psnr_ssim.py @@ -0,0 +1,233 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from ultraled.metrics.metric_util import reorder_image, to_y_channel +from ultraled.utils.color_util import rgb2ycbcr_pt +from ultraled.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) + return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. + """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) diff --git a/ultraled/metrics/test_metrics/test_psnr_ssim.py b/ultraled/metrics/test_metrics/test_psnr_ssim.py new file mode 100644 index 0000000..eab898b --- /dev/null +++ b/ultraled/metrics/test_metrics/test_psnr_ssim.py @@ -0,0 +1,52 @@ +import cv2 +import torch + +from ultraled.metrics import calculate_psnr, calculate_ssim +from ultraled.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt +from ultraled.utils import img2tensor + + +def test(img_path, img_path2, crop_border, test_y_channel=False): + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) + + # --------------------- Numpy --------------------- + psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) + ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) + print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') + + # --------------------- PyTorch (CPU) --------------------- + img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) + + psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') + + # --------------------- PyTorch (GPU) --------------------- + img = img.cuda() + img2 = img2.cuda() + psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) + print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') + + psnr_pth = calculate_psnr_pt( + torch.repeat_interleave(img, 2, dim=0), + torch.repeat_interleave(img2, 2, dim=0), + crop_border=crop_border, + test_y_channel=test_y_channel) + ssim_pth = calculate_ssim_pt( + torch.repeat_interleave(img, 2, dim=0), + torch.repeat_interleave(img2, 2, dim=0), + crop_border=crop_border, + test_y_channel=test_y_channel) + print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' + f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') + + +if __name__ == '__main__': + test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) + test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True) + + test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) + test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) diff --git a/ultraled/models/__init__.py b/ultraled/models/__init__.py new file mode 100644 index 0000000..4911a3c --- /dev/null +++ b/ultraled/models/__init__.py @@ -0,0 +1,29 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from ultraled.utils import get_root_logger, scandir +from ultraled.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'ultraled.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/ultraled/models/__pycache__/__init__.cpython-38.pyc b/ultraled/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..07a33bd Binary files /dev/null and b/ultraled/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/models/__pycache__/base_model.cpython-38.pyc b/ultraled/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000..cd0f079 Binary files /dev/null and b/ultraled/models/__pycache__/base_model.cpython-38.pyc differ diff --git a/ultraled/models/__pycache__/lr_scheduler.cpython-38.pyc b/ultraled/models/__pycache__/lr_scheduler.cpython-38.pyc new file mode 100644 index 0000000..e13d557 Binary files /dev/null and b/ultraled/models/__pycache__/lr_scheduler.cpython-38.pyc differ diff --git a/ultraled/models/__pycache__/raw_denoising_model.cpython-38.pyc b/ultraled/models/__pycache__/raw_denoising_model.cpython-38.pyc new file mode 100644 index 0000000..f87e66b Binary files /dev/null and b/ultraled/models/__pycache__/raw_denoising_model.cpython-38.pyc differ diff --git a/ultraled/models/base_model.py b/ultraled/models/base_model.py new file mode 100644 index 0000000..4bba8d3 --- /dev/null +++ b/ultraled/models/base_model.py @@ -0,0 +1,392 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from ultraled.models import lr_scheduler as lr_scheduler +from ultraled.utils import get_root_logger +from ultraled.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + elif optim_type == 'AdamW': + optimizer = torch.optim.AdamW(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'HandieLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.HandieLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'HandieStepLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.HandieStepLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'TorchCosineAnnealingLR': + print('Torch', 'CosineAnnealingLR') + for optimizer in self.optimizers: + self.schedulers.append(torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/ultraled/models/lr_scheduler.py b/ultraled/models/lr_scheduler.py new file mode 100644 index 0000000..0ac1162 --- /dev/null +++ b/ultraled/models/lr_scheduler.py @@ -0,0 +1,147 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +class HandieLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, lrs, last_epoch=-1): + self.milestones = milestones + self.lrs = lrs + super(HandieLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.milestones: + index = self.milestones.index(self.last_epoch) + return [self.lrs[index] for _ in range(len(self.optimizer.param_groups))] + return [group['lr'] for group in self.optimizer.param_groups] + + +class HandieStepLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gammas, last_epoch=-1): + self.milestones = milestones + self.gammas = gammas + assert len(self.gammas) == len(self.milestones) + super(HandieStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.milestones: + index = self.milestones.index(self.last_epoch) + return [self.gammas[index] * group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' + self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/ultraled/models/raw_denoising_model.py b/ultraled/models/raw_denoising_model.py new file mode 100644 index 0000000..d15e022 --- /dev/null +++ b/ultraled/models/raw_denoising_model.py @@ -0,0 +1,970 @@ +import torch +from torch import nn +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm +import os + +from copy import deepcopy +from ultraled.archs import build_network +from ultraled.losses import build_loss +from ultraled.metrics import calculate_metric +from ultraled.utils import get_root_logger, imwrite, tensor2img +from ultraled.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + +from ultraled.utils import load_CRF, raw2rgb_torch, raw2rgb_torch_grad +from ultraled.data.hdr_util import BlendMertens +from ultraled.data.noise_util_rawhdr import NoiseGenerator + +import yaml +import torch.nn.functional as F + +def sum_img_and_noise(img, noises): + for noise in noises: + img += noise + return img + +def gamma_correct(linear, eps=None): + if eps is None: + eps = torch.finfo(torch.float32).eps + srgb0 = 323 / 25 * linear + srgb1 = (211 * torch.maximum(torch.tensor(eps), linear)**(5 / 12) - 11) / 200 + return torch.where(linear <= 0.0031308, srgb0, srgb1) + + +def gamma_expansion(srgb, eps=None): + if eps is None: + eps = torch.finfo(torch.float32).eps + linear0 = 25 / 323 * srgb + linear1 = torch.maximum(torch.tensor(eps), ((200 * srgb + 11) / (211)))**(12 / 5) + return torch.where(srgb <= 0.04045, linear0, linear1) + +def half_size_demosaic(bayer_images): + r = bayer_images[..., 0:1, :, :] + gr = bayer_images[..., 1:2, :, :] + b = bayer_images[..., 2:3, :, :] + gb = bayer_images[..., 3:4, :, :] + g = (gr + gb) / 2 + linear_rgb = torch.cat([r, g, b], dim=-3) + return linear_rgb + +def apply_gains(bayer_images, wbs): + """Applies white balance to a batch of Bayer images.""" + B, N, C, _, _ = bayer_images.shape + outs = bayer_images * wbs.view(B, -1, C, 1, 1) + return outs + +def apply_ccms(images, ccms): + """Applies color correction matrices.""" + images = images.permute( + 0, 1, 3, 4, 2) # Permute the image tensor to BxHxWxC format from BxCxHxW format + images = images[:, :, :, :, None, :] + ccms = ccms[:, :, None, None, :, :] + outs = torch.sum(images * ccms, dim=-1) + # Re-Permute the tensor back to BxCxHxW format + outs = outs.permute(0, 1, 4, 2, 3) + return outs + + +def tiny_isp(im, wb, ccm): + im = half_size_demosaic(im) + im = apply_gains(im, wb) + im = apply_ccms(im, ccm) + return gamma_correct(im).clip(0, 1) + +def reverse_tiny_isp(srgb, wb, ccm): + raw = gamma_expansion(srgb) + raw = apply_ccms(raw, torch.inverse(ccm)) + raw = apply_gains(raw, 1.0 / wb) + # expand to 4 channels + raw_g = raw[..., 1:2, :, :] + raw = torch.cat([raw, raw_g], dim=-3) + return raw + +def exposure_fusion_from_raw(ims, wb, ccm, blend_menten): + """ + ims: B, N, C, H, W + wb: B, 4 + ccm: B, 3, 3 + """ + wb = wb[..., :3] # B, 1, 3 + ccm = ccm.unsqueeze(1) # B, 1, 3, 3 + srgbs = tiny_isp(ims, wb, ccm) # B, N, 3, H, W + merged = blend_menten(*[srgbs[:, i] for i in range(srgbs.shape[1])]) # B, 3, H, W + merged_raw = reverse_tiny_isp(merged.unsqueeze(1), wb, ccm).squeeze(1) # B, 4, H, W + return merged_raw + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + tuple: yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + +def yaml_load(f): + """Load yaml file or string. + + Args: + f (str): File path or a python string. + + Returns: + dict: Loaded dict. + """ + if os.path.isfile(f): + with open(f, 'r') as f: + return yaml.load(f, Loader=ordered_yaml()[0]) + else: + return yaml.load(f, Loader=ordered_yaml()[0]) + + + +def load_network(net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + print('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + print(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + +class IlluminanceCorrect(nn.Module): + def __init__(self): + super(IlluminanceCorrect, self).__init__() + + # Illuminance Correction + def forward(self, predict, source): + if predict.shape[0] != 1: + output = torch.zeros_like(predict) + if source.shape[0] != 1: + for i in range(predict.shape[0]): + output[i:i+1, ...] = self.correct(predict[i:i+1, ...], source[i:i+1, ...]) + else: + for i in range(predict.shape[0]): + output[i:i+1, ...] = self.correct(predict[i:i+1, ...], source) + else: + output = self.correct(predict, source) + return output + + def correct(self, predict, source): + N, C, H, W = predict.shape + predict = torch.clamp(predict, 0, 1) + assert N == 1 + output = torch.zeros_like(predict, device=predict.device) + pred_c = predict[source != 1] + source_c = source[source != 1] + + num = torch.dot(pred_c, source_c) + den = torch.dot(pred_c, pred_c) + output = num / den * predict + + return output + + +@MODEL_REGISTRY.register() +class RatioMapEstimatorModel(BaseModel): + + def __init__(self, opt): + super(RatioMapEstimatorModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + self.blend_merten = BlendMertens(contrast_weight=1.0, saturation_weight=1.0, exposure_weight=1.0, clip=True) + self.noise_gen = NoiseGenerator(**self.opt['noise_g']) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.opt.get('CRF_path', None) is not None: + self.CRF = load_CRF(self.opt['CRF_path']) + else: + self.CRF = None + + self.correct = self.opt['val'].get('illumination_correct', False) + if self.correct: + self.corrector = IlluminanceCorrect() + self.metric_in_srgb = self.opt.get('metric_in_srgb', False) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('srgb_opt'): + self.post_process = lambda x, wb, ccm: raw2rgb_torch_grad(x, wb, ccm, self.CRF) + else: + self.post_process = lambda x, wb, ccm: x + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + learnable_layers = train_opt.get('learnable_layers', None) + learnable_keys = train_opt.get('learnable_keys', None) + optim_params = [] + if learnable_layers is None and learnable_keys is None: + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + else: + if learnable_keys is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_keys\' for query training paarameters ...') + for k, v in self.net_g.named_parameters(): + for l_key in learnable_keys: + if l_key in k: + optim_params.append(v) + break + assert len(optim_params) > 0 + if learnable_layers is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_layers\' for query training paarameters ...') + for layer in learnable_layers: + if hasattr(self.net_g, layer): + optim_params.extend(list(eval(f'self.net_g.{layer}').parameters())) + else: + logger = get_root_logger() + logger.error(f'Layer {layer} is not in {self.net_g.__name__}.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + + self.intact = data['intact'].to(self.device) + self.lq_clean = data['lq_clean'].to(self.device) + self.ccm = data['ccm'].to(self.device) + self.wb = data['wb'].to(self.device) + self.ratio = data['ratio'].to(self.device) + self.ratio_all = data['ratio1'].to(self.device) + + mexp_lq = data['gt'].to(self.device) + fused_raw = torch.clamp(exposure_fusion_from_raw(mexp_lq, self.wb, self.ccm, self.blend_merten), 1e-8) + self.gt = self.lq_clean / fused_raw + self.ratio_all = self.ratio_all.unsqueeze(1).unsqueeze(1).unsqueeze(1) + self.gt = self.gt.clip(1 / self.ratio_all, self.ratio_all) + + # add noise + lq_im_patch = torch.clamp(self.lq_clean , min=0) * (16383 - 512) / self.ratio_all + im_patch, noise1 = self.noise_gen(lq_im_patch) + lq_im_patch = sum_img_and_noise(im_patch, noise1) / (16383 - 512) * self.ratio_all + + self.lq = lq_im_patch + + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + index = self.opt['network_g'] + net = index['type'] + + if str(net) == 'UNetArch' or str(net) == 'Restormer': + self.output = self.net_g(self.lq) + else: + self.output = self.net_g(self.lq, self.ratiomap) + + self.output = self.post_process(self.output, self.wb, self.ccm) + self.gt = self.post_process(self.gt, self.wb, self.ccm) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + # padding + h, w = self.lq.shape[2:] + pad_h = 16 - (h % 16) if h % 16 != 0 else 0 + pad_w = 16 - (w % 16) if w % 16 != 0 else 0 + + self.gt = self.gt.squeeze(0) + self.lq = nn.functional.pad(self.lq, [0, pad_w, 0, pad_h], mode='replicate') + self.gt = nn.functional.pad(self.gt, [0, pad_w, 0, pad_h], mode='replicate') + + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + else: + self.net_g.eval() + with torch.no_grad(): + index = self.opt['network_g'] + net = index['type'] + if str(net) == 'UNetArch': + self.output = self.net_g(self.lq) + + else: + self.output = self.net_g(self.lq, self.ratio) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + self.net_g.train() + + self.output = self.output[:, :, :h, :w] + self.lq = self.lq[:, :, :h, :w] + self.gt = self.gt[:, :, :h, :w] + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + @property + def calculate_metric_in_batch(self): + if hasattr(self, '_calculate_metric_in_batch'): + return self._calculate_metric_in_batch + + ## init + self._calculate_metric_in_batch = False + if self.opt['val'].get('calculate_metric_in_batch', False) is True: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + keys = filter(lambda x: x.startswith('val'), list(self.opt['datasets'].keys())) + for key in keys: + if self.opt['datasets'][key].get('batch_size_per_gpu', 1) > 1: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + return self._calculate_metric_in_batch + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', True) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + self._initialize_best_metric_results(dataset_name) + + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + if self.calculate_metric_in_batch: + count = 0 + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals(self.metric_in_srgb, save_img) + if not self.calculate_metric_in_batch: + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + else: + metric_data['img'] = visuals['result'] + metric_data['img2'] = visuals['gt'] + count += visuals['gt'].shape[0] + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + del self.ccm + del self.wb + torch.cuda.empty_cache() + + psnr = None + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + if self.calculate_metric_in_batch and not opt_['type'].endswith('_pt'): + opt_['type'] = opt_['type'] + '_pt' + metric = calculate_metric(metric_data, opt_) + if self.calculate_metric_in_batch: + metric = torch.sum(metric) + self.metric_results[name] += metric + if name == 'psnr': + psnr = metric + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + + if save_img: + if not self.calculate_metric_in_batch: + if not self.metric_in_srgb: + sr_img = tensor2img([visuals['result_srgb']]) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.jpg') + imwrite(sr_img, save_img_path) + else: + if not self.metric_in_srgb: + sr_imgs = tensor2img(visuals['result_srgb']) + else: + sr_imgs = tensor2img(visuals['result']) + if len(sr_imgs.shape) == 3: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}_{psnr:.4f}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}_{psnr:.4f}.jpg') + imwrite(sr_imgs, save_img_path) + else: + raise NotImplementedError() + + + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + if not self.calculate_metric_in_batch: + self.metric_results[metric] /= (idx + 1) + else: + self.metric_results[metric] /= count + self.metric_results[metric] = self.metric_results[metric].item() + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self, isp=True, save_img=False): + out_dict = OrderedDict() + if isp: + out_dict['lq'] = raw2rgb_torch(self.lq.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['result'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['gt'] = raw2rgb_torch(self.gt.detach(), self.wb, self.ccm, self.CRF, batch=True) + else: + out_dict['lq'] = self.lq.detach() + out_dict['result'] = self.output.detach() + out_dict['gt'] = self.gt.detach() + if save_img: + out_dict['result_srgb'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + if not self.calculate_metric_in_batch: + out_dict['result_srgb'] = out_dict['result_srgb'].cpu() + if not self.calculate_metric_in_batch: + out_dict['lq'] = out_dict['lq'].cpu() + out_dict['result'] = out_dict['result'].cpu() + out_dict['gt'] = out_dict['gt'].cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) + + + + + + +@MODEL_REGISTRY.register() +class RAWDenoiserModel(BaseModel): + + def __init__(self, opt): + super(RAWDenoiserModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + self.blend_merten = BlendMertens(contrast_weight=1.0, saturation_weight=1.0, exposure_weight=1.0, clip=True) + self.noise_gen = NoiseGenerator(**self.opt['noise_g']) + + network_d = build_network(opt['network_d']) + load_network(network_d, opt['network_d_path']) + self.mapnet = network_d.to(self.device) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.opt.get('CRF_path', None) is not None: + self.CRF = load_CRF(self.opt['CRF_path']) + else: + self.CRF = None + + self.correct = self.opt['val'].get('illumination_correct', False) + if self.correct: + self.corrector = IlluminanceCorrect() + self.metric_in_srgb = self.opt.get('metric_in_srgb', False) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('srgb_opt'): + self.post_process = lambda x, wb, ccm: raw2rgb_torch_grad(x, wb, ccm, self.CRF) + else: + self.post_process = lambda x, wb, ccm: x + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + learnable_layers = train_opt.get('learnable_layers', None) + learnable_keys = train_opt.get('learnable_keys', None) + optim_params = [] + if learnable_layers is None and learnable_keys is None: + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + else: + if learnable_keys is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_keys\' for query training paarameters ...') + for k, v in self.net_g.named_parameters(): + for l_key in learnable_keys: + if l_key in k: + optim_params.append(v) + break + assert len(optim_params) > 0 + if learnable_layers is not None: + logger = get_root_logger() + logger.info(f'Using \'learnable_layers\' for query training paarameters ...') + for layer in learnable_layers: + if hasattr(self.net_g, layer): + optim_params.extend(list(eval(f'self.net_g.{layer}').parameters())) + else: + logger = get_root_logger() + logger.error(f'Layer {layer} is not in {self.net_g.__name__}.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + + self.intact = data['intact'].to(self.device) + self.lq_clean = data['lq_clean'].to(self.device) + self.ccm = data['ccm'].to(self.device) + self.wb = data['wb'].to(self.device) + self.ratio = data['ratio'].to(self.device) + self.ratio_all = data['ratio1'].to(self.device) + + mexp_lq = data['gt'].to(self.device) + fused_raw = torch.clamp(exposure_fusion_from_raw(mexp_lq, self.wb, self.ccm, self.blend_merten), 1e-8) + self.ratio_all = self.ratio_all.unsqueeze(1).unsqueeze(1).unsqueeze(1) + # add noise + lq_im_patch = torch.clamp(self.lq_clean , min=0) * (16383 - 512) / self.ratio_all + im_patch, noise1 = self.noise_gen(lq_im_patch) + lq_im_patch = sum_img_and_noise(im_patch, noise1) / (16383 - 512) * self.ratio_all + + with torch.no_grad(): + ratiomap = self.mapnet(lq_im_patch) + self.lq = lq_im_patch / (ratiomap + 1e-8) + self.ratiomap = self.ratio_all / (ratiomap + 1e-8) + self.gt = fused_raw + self.lq, self.gt = self.lq.clip(0, 1), self.gt.clip(0, 1) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + index = self.opt['network_g'] + net = index['type'] + + if str(net) == 'UNetArch' or str(net) == 'Restormer': + self.output = self.net_g(self.lq) + else: + self.output = self.net_g(self.lq, self.ratiomap) + + self.output = self.post_process(self.output, self.wb, self.ccm) + self.gt = self.post_process(self.gt, self.wb, self.ccm) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + # padding + h, w = self.lq.shape[2:] + pad_h = 16 - (h % 16) if h % 16 != 0 else 0 + pad_w = 16 - (w % 16) if w % 16 != 0 else 0 + + self.gt = self.gt.squeeze(0) + self.lq = nn.functional.pad(self.lq, [0, pad_w, 0, pad_h], mode='replicate') + self.gt = nn.functional.pad(self.gt, [0, pad_w, 0, pad_h], mode='replicate') + + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + else: + self.net_g.eval() + with torch.no_grad(): + index = self.opt['network_g'] + net = index['type'] + if str(net) == 'UNetArch': + self.output = self.net_g(self.lq) + + else: + self.output = self.net_g(self.lq, self.ratio) + # illumination correction + if self.correct: + self.output = self.corrector(self.output, self.gt) + self.net_g.train() + + self.output = self.output[:, :, :h, :w] + self.lq = self.lq[:, :, :h, :w] + self.gt = self.gt[:, :, :h, :w] + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + @property + def calculate_metric_in_batch(self): + if hasattr(self, '_calculate_metric_in_batch'): + return self._calculate_metric_in_batch + + ## init + self._calculate_metric_in_batch = False + if self.opt['val'].get('calculate_metric_in_batch', False) is True: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + keys = filter(lambda x: x.startswith('val'), list(self.opt['datasets'].keys())) + for key in keys: + if self.opt['datasets'][key].get('batch_size_per_gpu', 1) > 1: + self._calculate_metric_in_batch = True + return self._calculate_metric_in_batch + return self._calculate_metric_in_batch + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', True) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + if self.calculate_metric_in_batch: + count = 0 + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals(self.metric_in_srgb, save_img) + if not self.calculate_metric_in_batch: + sr_img = tensor2img([visuals['result']]) + metric_data['img'] = sr_img + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + else: + metric_data['img'] = visuals['result'] + metric_data['img2'] = visuals['gt'] + count += visuals['gt'].shape[0] + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + del self.ccm + del self.wb + torch.cuda.empty_cache() + + psnr = None + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + if self.calculate_metric_in_batch and not opt_['type'].endswith('_pt'): + opt_['type'] = opt_['type'] + '_pt' + metric = calculate_metric(metric_data, opt_) + if self.calculate_metric_in_batch: + metric = torch.sum(metric) + self.metric_results[name] += metric + if name == 'psnr': + psnr = metric + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + + if save_img: + if not self.calculate_metric_in_batch: + if not self.metric_in_srgb: + sr_img = tensor2img([visuals['result_srgb']]) + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.jpg') + imwrite(sr_img, save_img_path) + else: + if not self.metric_in_srgb: + sr_imgs = tensor2img(visuals['result_srgb']) + else: + sr_imgs = tensor2img(visuals['result']) + if len(sr_imgs.shape) == 3: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.jpg') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}_{psnr:.4f}.jpg') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}_{psnr:.4f}.jpg') + imwrite(sr_imgs, save_img_path) + else: + raise NotImplementedError() + + + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + if not self.calculate_metric_in_batch: + self.metric_results[metric] /= (idx + 1) + else: + self.metric_results[metric] /= count + self.metric_results[metric] = self.metric_results[metric].item() + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self, isp=True, save_img=False): + out_dict = OrderedDict() + if isp: + out_dict['lq'] = raw2rgb_torch(self.lq.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['result'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + out_dict['gt'] = raw2rgb_torch(self.gt.detach(), self.wb, self.ccm, self.CRF, batch=True) + else: + out_dict['lq'] = self.lq.detach() + out_dict['result'] = self.output.detach() + out_dict['gt'] = self.gt.detach() + if save_img: + out_dict['result_srgb'] = raw2rgb_torch(self.output.detach(), self.wb, self.ccm, self.CRF, batch=True) + if not self.calculate_metric_in_batch: + out_dict['result_srgb'] = out_dict['result_srgb'].cpu() + if not self.calculate_metric_in_batch: + out_dict['lq'] = out_dict['lq'].cpu() + out_dict['result'] = out_dict['result'].cpu() + out_dict['gt'] = out_dict['gt'].cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) \ No newline at end of file diff --git a/ultraled/ops/__init__.py b/ultraled/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ultraled/ops/__pycache__/__init__.cpython-38.pyc b/ultraled/ops/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..1611f16 Binary files /dev/null and b/ultraled/ops/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/ops/dcn/__init__.py b/ultraled/ops/dcn/__init__.py new file mode 100644 index 0000000..32e3592 --- /dev/null +++ b/ultraled/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/ultraled/ops/dcn/deform_conv.py b/ultraled/ops/dcn/deform_conv.py new file mode 100644 index 0000000..6268ca8 --- /dev/null +++ b/ultraled/ops/dcn/deform_conv.py @@ -0,0 +1,379 @@ +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) +else: + try: + from . import deform_conv_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/ultraled/ops/dcn/src/deform_conv_cuda.cpp b/ultraled/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 0000000..b465c49 --- /dev/null +++ b/ultraled/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu b/ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000..98752dc --- /dev/null +++ b/ultraled/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/ultraled/ops/dcn/src/deform_conv_ext.cpp b/ultraled/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 0000000..41c6df6 --- /dev/null +++ b/ultraled/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/ultraled/ops/fused_act/__init__.py b/ultraled/ops/fused_act/__init__.py new file mode 100644 index 0000000..241dc07 --- /dev/null +++ b/ultraled/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/ultraled/ops/fused_act/fused_act.py b/ultraled/ops/fused_act/fused_act.py new file mode 100644 index 0000000..88edc44 --- /dev/null +++ b/ultraled/ops/fused_act/fused_act.py @@ -0,0 +1,95 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import os +import torch +from torch import nn +from torch.autograd import Function + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) +else: + try: + from . import fused_act_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/ultraled/ops/fused_act/src/fused_bias_act.cpp b/ultraled/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 0000000..85ed0a7 --- /dev/null +++ b/ultraled/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/ultraled/ops/fused_act/src/fused_bias_act_kernel.cu b/ultraled/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 0000000..54c7ff5 --- /dev/null +++ b/ultraled/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/ultraled/ops/upfirdn2d/__init__.py b/ultraled/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000..397e85b --- /dev/null +++ b/ultraled/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/ultraled/ops/upfirdn2d/src/upfirdn2d.cpp b/ultraled/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 0000000..43d0b67 --- /dev/null +++ b/ultraled/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 0000000..8870063 --- /dev/null +++ b/ultraled/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/ultraled/ops/upfirdn2d/upfirdn2d.py b/ultraled/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 0000000..d6122d5 --- /dev/null +++ b/ultraled/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,192 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import os +import torch +from torch.autograd import Function +from torch.nn import functional as F + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) +else: + try: + from . import upfirdn2d_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + _, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/ultraled/test.py b/ultraled/test.py new file mode 100644 index 0000000..a9bdb85 --- /dev/null +++ b/ultraled/test.py @@ -0,0 +1,45 @@ +import logging +import torch +from os import path as osp + +from ultraled.data import build_dataloader, build_dataset +from ultraled.models import build_model +from ultraled.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs +from ultraled.utils.options import dict2str, parse_options + + +def test_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(root_path, is_train=False) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + # create model + model = build_model(opt) + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + test_pipeline(root_path) diff --git a/ultraled/train.py b/ultraled/train.py new file mode 100644 index 0000000..df884a3 --- /dev/null +++ b/ultraled/train.py @@ -0,0 +1,222 @@ +import datetime +import logging +import math +import time +import torch +from os import path as osp +import sys + + +from ultraled.data import build_dataloader, build_dataset +from ultraled.data.data_sampler import EnlargedSampler +from ultraled.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from ultraled.models import build_model +from ultraled.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, + init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) +from ultraled.utils.options import copy_opt_file, dict2str, parse_options + + +def init_tb_loggers(opt): + # initialize wandb logger before tensorboard logger to allow proper sync + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') + is not None) and ('debug' not in opt['name']): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: + tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name'])) + return tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loaders = None, [] + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase.split('_')[0] == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}') + val_loaders.append(val_loader) + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loaders, total_epochs, total_iters + + +def load_resume_state(opt): + resume_state_path = None + if opt['auto_resume']: + state_path = osp.join('experiments', opt['name'], 'training_states') + if osp.isdir(state_path): + states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) + if len(states) != 0: + states = [float(v.split('.state')[0]) for v in states] + resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') + opt['path']['resume_state'] = resume_state_path + else: + if opt['path'].get('resume_state'): + resume_state_path = opt['path']['resume_state'] + + if resume_state_path is None: + resume_state = None + else: + device_id = torch.cuda.current_device() + resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) + check_resume(opt, resume_state['iter']) + return resume_state + + +def train_pipeline(root_path): + # parse options, set distributed setting, set random seed + opt, args = parse_options(root_path, is_train=True) + opt['root_path'] = root_path + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + resume_state = load_resume_state(opt) + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name'])) + + # copy the yml file to the experiment root + copy_opt_file(args.opt, opt['path']['experiments_root']) + + # WARNING: should not use get_root_logger in the above codes, including the called functions + # Otherwise the logger will not be properly initialized + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + # initialize wandb and tb loggers + tb_logger = init_tb_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loaders, total_epochs, total_iters = result + + # create model + model = build_model(opt) + if resume_state: # resume training + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') + data_timer, iter_timer = AvgTimer(), AvgTimer() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_timer.record() + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_timer.record() + if current_iter == 1: + # reset start time in msg_logger for more accurate eta_time + # not work in resume mode + msg_logger.reset_start_time() + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + # if epoch%2 == 0: + # logger.info('Saving models and training states.') + # model.save(epoch, current_iter) + + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): + if len(val_loaders) > 1: + logger.warning('Multiple validation datasets are *only* supported by SRModel.') + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + # pass + + data_timer.start() + iter_timer.start() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None: + for val_loader in val_loaders: + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/ultraled/utils/__init__.py b/ultraled/utils/__init__.py new file mode 100644 index 0000000..1a29958 --- /dev/null +++ b/ultraled/utils/__init__.py @@ -0,0 +1,53 @@ +from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb +from .diffjpeg import DiffJPEG +from .file_client import FileClient +from .img_process_util import USMSharp, usm_sharp +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt +from .common import AverageMeter +from .process import raw2rgb_postprocess, load_CRF, raw2rgb_v2, raw2rgb_torch, raw2rgb_torch_grad + +__all__ = [ + # color_util.py + 'bgr2ycbcr', + 'rgb2ycbcr', + 'rgb2ycbcr_pt', + 'ycbcr2bgr', + 'ycbcr2rgb', + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'AvgTimer', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt', + # diffjpeg + 'DiffJPEG', + # img_process_util + 'USMSharp', + 'usm_sharp', + # average meter + 'AverageMeter', + 'raw2rgb_postprocess', + 'raw2rgb_v2', + 'raw2rgb_torch', + 'raw2rgb_torch_grad', + 'load_CRF' +] diff --git a/ultraled/utils/__pycache__/__init__.cpython-38.pyc b/ultraled/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..e360cae Binary files /dev/null and b/ultraled/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/color_util.cpython-38.pyc b/ultraled/utils/__pycache__/color_util.cpython-38.pyc new file mode 100644 index 0000000..b8782c4 Binary files /dev/null and b/ultraled/utils/__pycache__/color_util.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/common.cpython-38.pyc b/ultraled/utils/__pycache__/common.cpython-38.pyc new file mode 100644 index 0000000..400636f Binary files /dev/null and b/ultraled/utils/__pycache__/common.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/diffjpeg.cpython-38.pyc b/ultraled/utils/__pycache__/diffjpeg.cpython-38.pyc new file mode 100644 index 0000000..20ce974 Binary files /dev/null and b/ultraled/utils/__pycache__/diffjpeg.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/dist_util.cpython-38.pyc b/ultraled/utils/__pycache__/dist_util.cpython-38.pyc new file mode 100644 index 0000000..9c38779 Binary files /dev/null and b/ultraled/utils/__pycache__/dist_util.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/file_client.cpython-38.pyc b/ultraled/utils/__pycache__/file_client.cpython-38.pyc new file mode 100644 index 0000000..e287755 Binary files /dev/null and b/ultraled/utils/__pycache__/file_client.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/img_process_util.cpython-38.pyc b/ultraled/utils/__pycache__/img_process_util.cpython-38.pyc new file mode 100644 index 0000000..d24a41d Binary files /dev/null and b/ultraled/utils/__pycache__/img_process_util.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/img_util.cpython-38.pyc b/ultraled/utils/__pycache__/img_util.cpython-38.pyc new file mode 100644 index 0000000..d21eea1 Binary files /dev/null and b/ultraled/utils/__pycache__/img_util.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/logger.cpython-38.pyc b/ultraled/utils/__pycache__/logger.cpython-38.pyc new file mode 100644 index 0000000..39df267 Binary files /dev/null and b/ultraled/utils/__pycache__/logger.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/matlab_functions.cpython-38.pyc b/ultraled/utils/__pycache__/matlab_functions.cpython-38.pyc new file mode 100644 index 0000000..96640ca Binary files /dev/null and b/ultraled/utils/__pycache__/matlab_functions.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/misc.cpython-38.pyc b/ultraled/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000..70f1ca0 Binary files /dev/null and b/ultraled/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/options.cpython-38.pyc b/ultraled/utils/__pycache__/options.cpython-38.pyc new file mode 100644 index 0000000..c7ff1bc Binary files /dev/null and b/ultraled/utils/__pycache__/options.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/process.cpython-38.pyc b/ultraled/utils/__pycache__/process.cpython-38.pyc new file mode 100644 index 0000000..961fbe1 Binary files /dev/null and b/ultraled/utils/__pycache__/process.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/registry.cpython-38.pyc b/ultraled/utils/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000..37dbf5b Binary files /dev/null and b/ultraled/utils/__pycache__/registry.cpython-38.pyc differ diff --git a/ultraled/utils/__pycache__/torchinterp1d.cpython-38.pyc b/ultraled/utils/__pycache__/torchinterp1d.cpython-38.pyc new file mode 100644 index 0000000..262f643 Binary files /dev/null and b/ultraled/utils/__pycache__/torchinterp1d.cpython-38.pyc differ diff --git a/ultraled/utils/color_util.py b/ultraled/utils/color_util.py new file mode 100644 index 0000000..4740d5c --- /dev/null +++ b/ultraled/utils/color_util.py @@ -0,0 +1,208 @@ +import numpy as np +import torch + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. + return out_img diff --git a/ultraled/utils/common.py b/ultraled/utils/common.py new file mode 100644 index 0000000..aef0606 --- /dev/null +++ b/ultraled/utils/common.py @@ -0,0 +1,17 @@ +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/ultraled/utils/diffjpeg.py b/ultraled/utils/diffjpeg.py new file mode 100644 index 0000000..32ca9f6 --- /dev/null +++ b/ultraled/utils/diffjpeg.py @@ -0,0 +1,515 @@ +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. + + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, _ = image.shape[1:3] + batch_size = image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. + """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered + + +if __name__ == '__main__': + import cv2 + + from ultraled.utils import img2tensor, tensor2img + + img_gt = cv2.imread('test.png') / 255. + + # -------------- cv2 -------------- # + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] + _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) + img_lq = np.float32(cv2.imdecode(encimg, 1)) + cv2.imwrite('cv2_JPEG_20.png', img_lq) + + # -------------- DiffJPEG -------------- # + jpeger = DiffJPEG(differentiable=False).cuda() + img_gt = img2tensor(img_gt) + img_gt = torch.stack([img_gt, img_gt]).cuda() + quality = img_gt.new_tensor([20, 40]) + out = jpeger(img_gt, quality=quality) + + cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) + cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1])) diff --git a/ultraled/utils/dist_util.py b/ultraled/utils/dist_util.py new file mode 100644 index 0000000..0fab887 --- /dev/null +++ b/ultraled/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/ultraled/utils/download_util.py b/ultraled/utils/download_util.py new file mode 100644 index 0000000..6adda71 --- /dev/null +++ b/ultraled/utils/download_util.py @@ -0,0 +1,99 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/ultraled/utils/file_client.py b/ultraled/utils/file_client.py new file mode 100644 index 0000000..89d83ab --- /dev/null +++ b/ultraled/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/ultraled/utils/flow_util.py b/ultraled/utils/flow_util.py new file mode 100644 index 0000000..3d7180b --- /dev/null +++ b/ultraled/utils/flow_util.py @@ -0,0 +1,170 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. + """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val + + return dequantized_arr diff --git a/ultraled/utils/img_process_util.py b/ultraled/utils/img_process_util.py new file mode 100644 index 0000000..52e02f0 --- /dev/null +++ b/ultraled/utils/img_process_util.py @@ -0,0 +1,83 @@ +import cv2 +import numpy as np +import torch +from torch.nn import functional as F + + +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +def usm_sharp(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. + + Input image: I; Blurry image: B. + 1. sharp = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * sharp + (1 - Mask) * I + + + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + sharp = img + weight * residual + sharp = np.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img diff --git a/ultraled/utils/img_util.py b/ultraled/utils/img_util.py new file mode 100644 index 0000000..3a5f1da --- /dev/null +++ b/ultraled/utils/img_util.py @@ -0,0 +1,172 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/ultraled/utils/lmdb_util.py b/ultraled/utils/lmdb_util.py new file mode 100644 index 0000000..e0a10f6 --- /dev/null +++ b/ultraled/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/ultraled/utils/logger.py b/ultraled/utils/logger.py new file mode 100644 index 0000000..dcbedf1 --- /dev/null +++ b/ultraled/utils/logger.py @@ -0,0 +1,217 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + def reset_start_time(self): + self.start_time = time.time() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. + """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + if isinstance(v, str): + message += f'{k}: {v} ' + continue + else: + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = get_root_logger() + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. + """ + import torch + import torchvision + + # from basicsr.version import __version__ +# msg = r""" +# ____ _ _____ ____ +# / __ ) ____ _ _____ (_)_____/ ___/ / __ \ +# / __ |/ __ `// ___// // ___/\__ \ / /_/ / +# / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ +# /_____/ \__,_//____//_/ \___//____//_/ |_| +# ______ __ __ __ __ +# / ____/____ ____ ____/ / / / __ __ _____ / /__ / / +# / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / +# / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ +# \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) +# """ +# msg += ('\nVersion Information: ' +# f'\n\tBasicSR: {__version__}' +# f'\n\tPyTorch: {torch.__version__}' +# f'\n\tTorchVision: {torchvision.__version__}') +# return msg diff --git a/ultraled/utils/matlab_functions.py b/ultraled/utils/matlab_functions.py new file mode 100644 index 0000000..a201f79 --- /dev/null +++ b/ultraled/utils/matlab_functions.py @@ -0,0 +1,178 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + squeeze_flag = False + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if squeeze_flag: + out_2 = out_2.squeeze(0) + if numpy_type: + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + + return out_2 diff --git a/ultraled/utils/misc.py b/ultraled/utils/misc.py new file mode 100644 index 0000000..728fef8 --- /dev/null +++ b/ultraled/utils/misc.py @@ -0,0 +1,141 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/ultraled/utils/options.py b/ultraled/utils/options.py new file mode 100644 index 0000000..09b749b --- /dev/null +++ b/ultraled/utils/options.py @@ -0,0 +1,211 @@ +import argparse +import random +import torch +import yaml +from collections import OrderedDict +from os import path as osp +import os +from ultraled.utils import set_random_seed +from ultraled.utils.dist_util import get_dist_info, init_dist, master_only + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + tuple: yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def yaml_load(f): + """Load yaml file or string. + + Args: + f (str): File path or a python string. + + Returns: + dict: Loaded dict. + """ + if os.path.isfile(f): + with open(f, 'r') as f: + return yaml.load(f, Loader=ordered_yaml()[0]) + else: + return yaml.load(f, Loader=ordered_yaml()[0]) + + + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + with open(args.opt, mode='r') as f: + opt = yaml.load(f, Loader=ordered_yaml()[0]) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) diff --git a/ultraled/utils/process.py b/ultraled/utils/process.py new file mode 100644 index 0000000..32e302b --- /dev/null +++ b/ultraled/utils/process.py @@ -0,0 +1,223 @@ +"""Forward processing of raw data to sRGB images. + +Unprocessing Images for Learned Raw Denoising +http://timothybrooks.com/tech/unprocessing +""" + +import numpy as np +import torch +from ultraled.utils.torchinterp1d import Interp1d +from os.path import join + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +def apply_gains(bayer_images, wbs): + """Applies white balance to a batch of Bayer images.""" + N, C, _, _ = bayer_images.shape + outs = bayer_images * wbs.view(N, C, 1, 1) + return outs + +def apply_ccms(images, ccms): + """Applies color correction matrices.""" + images = images.permute( + 0, 2, 3, 1) # Permute the image tensor to BxHxWxC format from BxCxHxW format + images = images[:, :, :, None, :] + ccms = ccms[:, None, None, :, :] + outs = torch.sum(images * ccms, dim=-1) + # Re-Permute the tensor back to BxCxHxW format + outs = outs.permute(0, 3, 1, 2) + return outs + + +def gamma_compression(images, gamma=2.2): + """Converts from linear to gamma space.""" + outs = torch.clamp(images, min=1e-8) ** (1 / gamma) + # outs = (1 + gamma[0]) * np.power(images, 1.0/gamma[1]) - gamma[0] + gamma[2]*images + outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255 + return outs + + +def gamma_compression_grad(images, gamma=2.2): + """Converts from linear to gamma space.""" + outs = torch.clamp(images, min=1e-8) ** (1 / gamma) + # outs = (1 + gamma[0]) * np.power(images, 1.0/gamma[1]) - gamma[0] + gamma[2]*images + return outs + + +def binning(bayer_images): + """RGBG -> RGB""" + lin_rgb = torch.stack([ + bayer_images[:,0,...], + torch.mean(bayer_images[:, [1,3], ...], dim=1), + bayer_images[:,2,...]], dim=1) + + return lin_rgb + + +def process(bayer_images, wbs, cam2rgbs, gamma=2.2, CRF=None): + """Processes a batch of Bayer RGBG images into sRGB images.""" + orig_img = bayer_images + # White balance. + bayer_images = apply_gains(orig_img, wbs) + # Binning + bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0) + images = binning(bayer_images) + # Color correction. + images = apply_ccms(images, cam2rgbs) + # Gamma compression. + images = torch.clamp(images, min=0.0, max=1.0) + if CRF is None: + images = gamma_compression(images, gamma) + else: + images = camera_response_function(images, CRF) + + return images + + +def process_grad(bayer_images, wbs, cam2rgbs, gamma=2.2, CRF=None): + """Processes a batch of Bayer RGBG images into sRGB images.""" + orig_img = bayer_images + # White balance. + bayer_images = apply_gains(orig_img, wbs) + # Binning + bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0) + images = binning(bayer_images) + # Color correction. + images = apply_ccms(images, cam2rgbs) + # Gamma compression. + images = torch.clamp(images, min=0.0, max=1.0) + if CRF is None: + images = gamma_compression_grad(images, gamma) + else: + images = camera_response_function_grad(images, CRF) + + return images + + +def camera_response_function(images, CRF): + E, fs = CRF # unpack CRF data + + outs = torch.zeros_like(images) + device = images.device + + for i in range(images.shape[0]): + img = images[i].view(3, -1) + out = Interp1d()(E.to(device), fs.to(device), img) + outs[i, ...] = out.view(3, images.shape[2], images.shape[3]) + + outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255 + return outs + +def camera_response_function_grad(images, CRF): + E, fs = CRF # unpack CRF data + + outs = torch.zeros_like(images) + device = images.device + + for i in range(images.shape[0]): + img = images[i].view(3, -1) + out = Interp1d()(E.to(device), fs.to(device), img) + outs[i, ...] = out.view(3, images.shape[2], images.shape[3]) + + return outs + +def raw2rgb(packed_raw, raw, CRF=None, gamma=2.2): + """Raw2RGB pipeline (preprocess version)""" + wb = np.array(raw.camera_whitebalance) + wb /= wb[1] + cam2rgb = raw.rgb_camera_matrix[:3, :3] + + if isinstance(packed_raw, np.ndarray): + packed_raw = torch.from_numpy(packed_raw).float() + + wb = torch.from_numpy(wb).float().to(packed_raw.device) + cam2rgb = torch.from_numpy(cam2rgb).float().to(packed_raw.device) + + out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy() + + return out + + +def raw2rgb_v2(packed_raw, wb, ccm, CRF=None, gamma=2.2): # RGBG + packed_raw = torch.from_numpy(packed_raw).float() + wb = torch.from_numpy(wb).float() + cam2rgb = torch.from_numpy(ccm).float() + out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy() + return out + + +def raw2rgb_torch(packed_raw, wb, ccm, CRF=None, gamma=2.2, batch=False): # RGBG + if batch: + out = process(packed_raw, wbs=wb, cam2rgbs=ccm, gamma=gamma, CRF=CRF) + else: + out = process(packed_raw[None], wbs=wb[None], cam2rgbs=ccm[None], gamma=gamma, CRF=CRF) + return out + +def raw2rgb_torch_grad(packed_raw, wb, ccm, CRF=None, gamma=2.2): # RGBG + out = process_grad(packed_raw, wbs=wb, cam2rgbs=ccm, gamma=gamma, CRF=CRF) + return out + +def raw2rgb_postprocess(packed_raw, raw, CRF=None): + """Raw2RGB pipeline (postprocess version)""" + assert packed_raw.ndimension() == 4 and packed_raw.shape[0] == 1 + wb = np.array(raw.camera_whitebalance) + wb /= wb[1] + cam2rgb = raw.rgb_camera_matrix[:3, :3] + + wb = torch.from_numpy(wb[None]).float().to(packed_raw.device) + cam2rgb = torch.from_numpy(cam2rgb[None]).float().to(packed_raw.device) + out = process(packed_raw, wbs=wb, cam2rgbs=cam2rgb, gamma=2.2, CRF=CRF) + return out + +def read_wb_ccm(raw): + wb = np.array(raw.camera_whitebalance) + wb /= wb[1] + wb = wb.astype(np.float32) + ccm = raw.rgb_camera_matrix[:3, :3].astype(np.float32) + return wb, ccm + + +def read_emor(address): + def _read_curve(lst): + curve = [l.strip() for l in lst] + curve = ' '.join(curve) + curve = np.array(curve.split()).astype(np.float32) + return curve + + with open(address) as f: + lines = f.readlines() + k = 1 + E = _read_curve(lines[k:k+256]) + k += 257 + f0 = _read_curve(lines[k:k+256]) + hs = [] + for _ in range(25): + k += 257 + hs.append(_read_curve(lines[k:k+256])) + + hs = np.array(hs) + + return E, f0, hs + + +def read_dorf(address): + with open(address) as f: + lines = f.readlines() + curve_names = lines[0::6] + Es = lines[3::6] + Bs = lines[5::6] + + Es = [np.array(E.strip().split()).astype(np.float32) for E in Es] + Bs = [np.array(B.strip().split()).astype(np.float32) for B in Bs] + + return curve_names, Es, Bs + + +def load_CRF(EMoR_path): + # init CRF function + fs = np.loadtxt(join(EMoR_path, 'CRF_SonyA7S2_5.txt')) + E, _, _ = read_emor(join(EMoR_path, 'emor.txt')) + E = torch.from_numpy(E).repeat(3, 1) + fs = torch.from_numpy(fs) + CRF = (E, fs) + return CRF diff --git a/ultraled/utils/registry.py b/ultraled/utils/registry.py new file mode 100644 index 0000000..5e72ef7 --- /dev/null +++ b/ultraled/utils/registry.py @@ -0,0 +1,88 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None, suffix=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class, suffix) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj, suffix) + + def get(self, name, suffix='basicsr'): + ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/ultraled/utils/torchinterp1d.py b/ultraled/utils/torchinterp1d.py new file mode 100644 index 0000000..5e7d1ff --- /dev/null +++ b/ultraled/utils/torchinterp1d.py @@ -0,0 +1,164 @@ +import torch +import contextlib + +class Interp1d(torch.autograd.Function): + def __call__(self, x, y, xnew, out=None): + return self.forward(x, y, xnew, out) + + def forward(ctx, x, y, xnew, out=None): + """ + Linear 1D interpolation on the GPU for Pytorch. + This function returns interpolated values of a set of 1-D functions at + the desired query points `xnew`. + This function is working similarly to Matlab™ or scipy functions with + the `linear` interpolation mode on, except that it parallelises over + any number of desired interpolation problems. + The code will run on GPU if all the tensors provided are on a cuda + device. + + Parameters + ---------- + x : (N, ) or (D, N) Pytorch Tensor + A 1-D or 2-D tensor of real values. + y : (N,) or (D, N) Pytorch Tensor + A 1-D or 2-D tensor of real values. The length of `y` along its + last dimension must be the same as that of `x` + xnew : (P,) or (D, P) Pytorch Tensor + A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if + _both_ `x` and `y` are 1-D. Otherwise, its length along the first + dimension must be the same as that of whichever `x` and `y` is 2-D. + out : Pytorch Tensor, same shape as `xnew` + Tensor for the output. If None: allocated automatically. + + """ + # making the vectors at least 2D + is_flat = {} + require_grad = {} + v = {} + device = [] + eps = torch.finfo(y.dtype).eps + for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items(): + assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\ + 'at most 2-D.' + if len(vec.shape) == 1: + v[name] = vec[None, :] + else: + v[name] = vec + is_flat[name] = v[name].shape[0] == 1 + require_grad[name] = vec.requires_grad + device = list(set(device + [str(vec.device)])) + assert len(device) == 1, 'All parameters must be on the same device.' + device = device[0] + + # Checking for the dimensions + assert (v['x'].shape[1] == v['y'].shape[1] + and ( + v['x'].shape[0] == v['y'].shape[0] + or v['x'].shape[0] == 1 + or v['y'].shape[0] == 1 + ) + ), ("x and y must have the same number of columns, and either " + "the same number of row or one of them having only one " + "row.") + + reshaped_xnew = False + if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1) + and (v['xnew'].shape[0] > 1)): + # if there is only one row for both x and y, there is no need to + # loop over the rows of xnew because they will all have to face the + # same interpolation problem. We should just stack them together to + # call interp1d and put them back in place afterwards. + original_xnew_shape = v['xnew'].shape + v['xnew'] = v['xnew'].contiguous().view(1, -1) + reshaped_xnew = True + + # identify the dimensions of output and check if the one provided is ok + D = max(v['x'].shape[0], v['xnew'].shape[0]) + shape_ynew = (D, v['xnew'].shape[-1]) + if out is not None: + if out.numel() != shape_ynew[0]*shape_ynew[1]: + # The output provided is of incorrect shape. + # Going for a new one + out = None + else: + ynew = out.reshape(shape_ynew) + if out is None: + ynew = torch.zeros(*shape_ynew, device=device) + + # moving everything to the desired device in case it was not there + # already (not handling the case things do not fit entirely, user will + # do it if required.) + for name in v: + v[name] = v[name].to(device) + + # calling searchsorted on the x values. + ind = ynew.long() + + # expanding xnew to match the number of rows of x in case only one xnew is + # provided + if v['xnew'].shape[0] == 1: + v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1) + + torch.searchsorted(v['x'].contiguous(), + v['xnew'].contiguous(), out=ind) + + # the `-1` is because searchsorted looks for the index where the values + # must be inserted to preserve order. And we want the index of the + # preceeding value. + ind -= 1 + # we clamp the index, because the number of intervals is x.shape-1, + # and the left neighbour should hence be at most number of intervals + # -1, i.e. number of columns in x -2 + ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1) + + # helper function to select stuff according to the found indices. + def sel(name): + if is_flat[name]: + return v[name].contiguous().view(-1)[ind] + return torch.gather(v[name], 1, ind) + + # activating gradient storing for everything now + enable_grad = False + saved_inputs = [] + for name in ['x', 'y', 'xnew']: + if require_grad[name]: + enable_grad = True + saved_inputs += [v[name]] + else: + saved_inputs += [None, ] + # assuming x are sorted in the dimension 1, computing the slopes for + # the segments + is_flat['slopes'] = is_flat['x'] + # now we have found the indices of the neighbors, we start building the + # output. Hence, we start also activating gradient tracking + with torch.enable_grad() if enable_grad else contextlib.suppress(): + v['slopes'] = ( + (v['y'][:, 1:]-v['y'][:, :-1]) + / + (eps + (v['x'][:, 1:]-v['x'][:, :-1])) + ) + + # now build the linear interpolation + ynew = sel('y') + sel('slopes')*( + v['xnew'] - sel('x')) + + if reshaped_xnew: + ynew = ynew.view(original_xnew_shape) + + ctx.save_for_backward(ynew, *saved_inputs) + return ynew + + @staticmethod + def backward(ctx, grad_out): + inputs = ctx.saved_tensors[1:] + gradients = torch.autograd.grad( + ctx.saved_tensors[0], + [i for i in inputs if i is not None], + grad_out, retain_graph=True) + result = [None, ] * 5 + pos = 0 + for index in range(len(inputs)): + if inputs[index] is not None: + result[index] = gradients[pos] + pos += 1 + return (*result,) \ No newline at end of file