@@ -2,6 +2,7 @@ package encryption
22
33import (
44 "encoding/base64"
5+ "errors"
56 "fmt"
67 "reflect"
78 "slices"
@@ -83,6 +84,78 @@ func EncryptFields(v any) (any, error) {
8384 return out , nil
8485}
8586
87+ // Decrypts everything from raw map to v. v has to be a pointer to struct.
88+ func DecryptFields (v any , raw map [string ]any ) error {
89+ if siv == nil {
90+ return nil // skip if no key/instance
91+ }
92+
93+ val := reflect .ValueOf (v )
94+ if val .Kind () != reflect .Pointer || val .Elem ().Kind () != reflect .Struct {
95+ return errors .New ("v must be a pointer to a struct" )
96+ }
97+
98+ val = val .Elem ()
99+ typ := val .Type ()
100+
101+ for i := 0 ; i < val .NumField (); i ++ {
102+ fieldValue := val .Field (i )
103+ fieldType := typ .Field (i )
104+ fieldName := fieldType .Name
105+
106+ // skip if field is unexported or marked to be ignored
107+ if ! fieldValue .CanSet () || fieldType .Tag .Get ("json" ) == "-" {
108+ continue
109+ }
110+
111+ // get the value from the map using the struct field name
112+ data , ok := raw [getJsonFieldName (fieldType )]
113+ if ! ok || data == nil {
114+ continue
115+ }
116+
117+ // recursive structs
118+ if fieldValue .Kind () == reflect .Struct || (fieldValue .Kind () == reflect .Pointer && fieldType .Type .Elem ().Kind () == reflect .Struct ) {
119+ // don't recurse time.Time
120+ if _ , isTime := fieldValue .Interface ().(time.Time ); ! isTime {
121+ nestedMap , isMap := data .(map [string ]any )
122+ if isMap {
123+ // initialize pointer if nil
124+ if fieldValue .Kind () == reflect .Pointer && fieldValue .IsNil () {
125+ fieldValue .Set (reflect .New (fieldValue .Type ().Elem ()))
126+ }
127+ if err := DecryptFields (fieldValue .Addr ().Interface (), nestedMap ); err != nil {
128+ return err
129+ }
130+ continue
131+ }
132+ }
133+ }
134+
135+ // decrypt encrypted fields
136+ cipherStr , ok := data .(string )
137+ if ! ok {
138+ continue
139+ }
140+
141+ ciphertext , err := base64 .StdEncoding .DecodeString (cipherStr )
142+ if err != nil {
143+ return fmt .Errorf ("failed to decode base64 string of field %s: %w" , fieldName , err )
144+ }
145+ plaintext , err := siv .Open (nil , nil , ciphertext , []byte (fieldName ))
146+ if err != nil {
147+ return fmt .Errorf ("failed to decrypt field %s: %w" , fieldName , err )
148+ }
149+
150+ // convert decrypted string back to the field's actual type
151+ if err := decodeValue (fieldValue , string (plaintext )); err != nil {
152+ return fmt .Errorf ("failed to parse field %s: %w" , fieldName , err )
153+ }
154+ }
155+
156+ return nil
157+ }
158+
86159// Helper that returns the string representation of field type.
87160func encodeValue (field reflect.Value ) (string , error ) {
88161 // dereference pointer if necessary
@@ -108,9 +181,39 @@ func encodeValue(field reflect.Value) (string, error) {
108181 }
109182}
110183
184+ // Helper that converts string representation of a value back into its field type.
185+ func decodeValue (field reflect.Value , val string ) error {
186+ switch field .Kind () {
187+ case reflect .String :
188+ field .SetString (val )
189+ case reflect .Int , reflect .Int64 , reflect .Int32 , reflect .Int16 , reflect .Int8 :
190+ i , err := strconv .ParseInt (val , 10 , 64 )
191+ if err != nil {
192+ return err
193+ }
194+ field .SetInt (i )
195+ case reflect .Struct :
196+ if _ , ok := field .Interface ().(time.Time ); ok {
197+ t , err := time .Parse (time .RFC3339 , val )
198+ if err != nil {
199+ return err
200+ }
201+ field .Set (reflect .ValueOf (t ))
202+ }
203+ }
204+ return nil
205+ }
206+
111207// Returns true if `json:"field_name,omitzero"` and false if `json:"field_name"`.
112208func hasOmitzero (field reflect.StructField ) bool {
113209 tag := field .Tag .Get ("json" )
114210 tagParts := strings .Split (tag , "," )
115211 return slices .Contains (tagParts , "omitzero" )
116212}
213+
214+ // Returns the field_name from `json:"field_name"`.
215+ func getJsonFieldName (field reflect.StructField ) string {
216+ tag := field .Tag .Get ("json" )
217+ tagParts := strings .Split (tag , "," )
218+ return tagParts [0 ] // always len > 0 -> safe
219+ }
0 commit comments