11from pathlib import Path
22from typing import Any , Dict , Optional
33
4+ from attr import dataclass
45from wsgidav .dav_error import HTTP_FORBIDDEN , DAVError
56from wsgidav .fs_dav_provider import FileResource , FilesystemProvider , FolderResource
67
78from .token import Token
8- from .type_alias import PreWriteType
9- from .util import requests_session
9+ from .type_alias import WriteType
10+ from .util import cattrib , requests_session
1011
1112
1213class ManabiFolderResource (FolderResource ):
@@ -51,18 +52,26 @@ def set_last_modified(self, dest_path, time_stamp, dry_run):
5152 raise DAVError (HTTP_FORBIDDEN )
5253
5354
55+ @dataclass
56+ class CallbackHookConfig :
57+ pre_write_hook : Optional [str ] = cattrib (Optional [str ], default = None )
58+ pre_write_callback : Optional [WriteType ] = cattrib (Optional [WriteType ], default = None )
59+ post_write_hook : Optional [str ] = cattrib (Optional [str ], default = None )
60+ post_write_callback : Optional [WriteType ] = cattrib (
61+ Optional [WriteType ], default = None
62+ )
63+
64+
5465class ManabiFileResource (FileResource ):
5566 def __init__ (
5667 self ,
5768 path ,
5869 environ ,
5970 file_path ,
6071 * ,
61- pre_write_hook : Optional [str ] = None ,
62- pre_write_callback : Optional [PreWriteType ] = None ,
72+ cb_hook_config : Optional [CallbackHookConfig ] = None ,
6373 ):
64- self ._pre_write_hook = pre_write_hook
65- self ._pre_write_callback = pre_write_callback
74+ self ._cb_config = cb_hook_config
6675 self ._token = environ ["manabi.token" ]
6776 super ().__init__ (path , environ , file_path )
6877
@@ -78,22 +87,46 @@ def support_recursive_move(self, dest_path):
7887 def move_recursive (self , dest_path ):
7988 raise DAVError (HTTP_FORBIDDEN )
8089
81- def begin_write (self , * , content_type ):
90+ def get_token_and_config (self ):
8291 token = self ._token
83- if token :
84- pre_hook = self ._pre_write_hook
85- pre_callback = self ._pre_write_callback
86-
87- if pre_hook :
88- session = requests_session ()
89- res = session .post (pre_hook , data = token .encode ())
90- if res .status_code != 200 :
91- raise DAVError (HTTP_FORBIDDEN )
92- if pre_callback :
93- if not pre_callback (token ):
94- raise DAVError (HTTP_FORBIDDEN )
95- # The hook returned and hopefully created a new version.
96- # Now we can save.
92+ config = self ._cb_config
93+ return token and config , token , config
94+
95+ def process_post_write_hooks (self ):
96+ ok , token , config = self .get_token_and_config ()
97+ if not ok :
98+ return
99+ post_hook = config .post_write_hook
100+ post_callback = config .post_write_callback
101+
102+ if post_hook :
103+ session = requests_session ()
104+ session .post (post_hook , data = token .encode ())
105+ if post_callback :
106+ post_callback (token )
107+
108+ def end_write (self , * , with_errors ):
109+ if not with_errors :
110+ self .process_post_write_hooks ()
111+
112+ def process_pre_write_hooks (self ):
113+ ok , token , config = self .get_token_and_config ()
114+ if not ok :
115+ return
116+ pre_hook = config .pre_write_hook
117+ pre_callback = config .pre_write_callback
118+
119+ if pre_hook :
120+ session = requests_session ()
121+ res = session .post (pre_hook , data = token .encode ())
122+ if res .status_code != 200 :
123+ raise DAVError (HTTP_FORBIDDEN )
124+ if pre_callback :
125+ if not pre_callback (token ):
126+ raise DAVError (HTTP_FORBIDDEN )
127+
128+ def begin_write (self , * , content_type ):
129+ self .process_pre_write_hooks ()
97130 return super ().begin_write (content_type = content_type )
98131
99132
@@ -104,11 +137,9 @@ def __init__(
104137 * ,
105138 readonly = False ,
106139 shadow = None ,
107- pre_write_hook : Optional [str ] = None ,
108- pre_write_callback : Optional [PreWriteType ] = None ,
140+ cb_hook_config : Optional [CallbackHookConfig ] = None ,
109141 ):
110- self ._pre_write_hook = pre_write_hook
111- self ._pre_write_callback = pre_write_callback
142+ self ._cb_hook_config = cb_hook_config
112143 super ().__init__ (root_folder , readonly = readonly , shadow = shadow )
113144
114145 def get_resource_inst (self , path : str , environ : Dict [str , Any ]):
@@ -128,8 +159,7 @@ def get_resource_inst(self, path: str, environ: Dict[str, Any]):
128159 path ,
129160 environ ,
130161 fp ,
131- pre_write_hook = self ._pre_write_hook ,
132- pre_write_callback = self ._pre_write_callback ,
162+ cb_hook_config = self ._cb_hook_config ,
133163 )
134164 else :
135165 return None
0 commit comments