2020
2121import re
2222from io import BytesIO
23- from typing import Iterator , Union , Optional , Tuple , List
23+ from typing import Generic , Iterator , Type , TypeVar , Union , Optional , Tuple , List
2424from urllib .parse import parse_qs
2525from wsgiref .headers import Headers
2626from collections .abc import MutableMapping as DictMixin
@@ -280,8 +280,10 @@ def parse_options_header(header, options=None, unquote=header_unquote):
280280_BODY = "BODY"
281281_COMPLETE = "END"
282282
# Type variable for the segment class a parser emits. The name passed to
# TypeVar() must match the variable it is assigned to, otherwise static type
# checkers reject the declaration.
t_segment = TypeVar("t_segment", bound="MultipartSegment")
284+
285+ class PushMultipartParser (Generic [t_segment ]):
283286
284- class PushMultipartParser :
285287 def __init__ (
286288 self ,
287289 boundary : Union [str , bytes ],
@@ -292,6 +294,7 @@ def __init__(
292294 max_segment_count = inf , # unlimited
293295 header_charset = "utf8" ,
294296 strict = False ,
297+ segment_class : Optional [Type [t_segment ]] = None ,
295298 ):
296299 """A push-based (incremental, non-blocking) parser for multipart/form-data.
297300
@@ -311,6 +314,8 @@ def __init__(
311314 :param max_segment_count: Maximum number of segments.
312315 :param header_charset: Charset for header names and values.
313316 :param strict: Enables additional format and sanity checks.
317+
318+ :param segment_class: Class for emitted segments, defaults to `MultipartSegment`.
314319 """
315320 self .boundary = to_bytes (boundary )
316321 self .content_length = content_length
@@ -321,13 +326,17 @@ def __init__(
321326 self .max_segment_count = max_segment_count
322327 self .strict = strict
323328
324- self ._delimiter = b"\r \n --" + self .boundary
# Validate the *argument*: ``self.segment_class`` is not assigned yet at
# this point, so reading it here would raise AttributeError whenever a
# custom class is passed. Anything that is not a MultipartSegment subclass
# falls back to the default segment class.
if segment_class and issubclass(segment_class, MultipartSegment):
    self.segment_class = segment_class
else:
    self.segment_class = MultipartSegment
325333
326334 # Internal parser state
335+ self ._delimiter = b"\r \n --" + self .boundary
327336 self ._parsed = 0
328- self ._fieldcount = 0
329337 self ._buffer = bytearray ()
330- self ._current = None
338+ self ._segment_count = 0
339+ self ._segment = None
331340 self ._state = _PREAMBLE
332341
333342 #: True if the parser reached the end of the multipart stream, stopped
@@ -344,7 +353,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
344353
345354 def parse (
346355 self , chunk : Union [bytes , bytearray ]
347- ) -> Iterator [Union ["MultipartSegment" , bytearray , None ]]:
356+ ) -> Iterator [Union [t_segment , bytearray , None ]]:
348357 """Parse a chunk of data and yield as many result objects as possible
349358 with the data given.
350359
@@ -406,7 +415,7 @@ def parse(
406415 tail = buffer [next_start - 2 : next_start ]
407416
408417 if tail == b"\r \n " : # Normal delimiter found
409- self ._current = MultipartSegment ( self )
418+ self ._segment = self . _new_segment ( )
410419 self ._state = _HEADER
411420 offset = next_start
412421 continue
@@ -433,12 +442,12 @@ def parse(
433442 nl = buffer .find (b"\r \n " , offset )
434443
435444 if nl > offset : # Non-empty header line
436- self ._current . _add_headerline (buffer [offset :nl ])
445+ self ._segment . _on_headerline (buffer [offset :nl ])
437446 offset = nl + 2
438447 continue
439448 elif nl == offset : # Empty header line -> End of header section
440- self ._current . _close_headers ()
441- yield self ._current
449+ self ._segment . _on_header_complete ()
450+ yield self ._segment
442451 self ._state = _BODY
443452 offset += 2
444453 continue
@@ -463,27 +472,25 @@ def parse(
463472
464473 if tail == b"\r \n " or tail == b"--" :
465474 if index > offset :
466- self ._current ._update_size (index - offset )
467- yield buffer [offset :index ]
475+ yield self ._segment ._on_data (buffer [offset :index ])
468476
469477 offset = next_start
470- self ._current . _mark_complete ()
478+ self ._segment . _on_data_complete ()
471479 yield None # End of segment
472480
473481 if tail == b"--" : # Last delimiter
474482 self ._state = _COMPLETE
475483 break
476484 else : # Normal delimiter
477- self ._current = MultipartSegment ( self )
485+ self ._segment = self . _new_segment ( )
478486 self ._state = _HEADER
479487 continue
480488
481489 # Keep enough in buffer to account for a partial delimiter at
482490 # the end, but emit the rest.
483491 chunk_end = bufferlen - (d_len + 1 )
484492 assert chunk_end > offset # Always true
485- self ._current ._update_size (chunk_end - offset )
486- yield buffer [offset :chunk_end ]
493+ yield self ._segment ._on_data (buffer [offset :chunk_end ])
487494 offset = chunk_end
488495 break # wait for more data
489496
@@ -501,6 +508,12 @@ def parse(
501508 self .close (check_complete = False )
502509 raise
503510
def _new_segment(self) -> t_segment:
    """ Count one more segment and build it with the configured segment class.

    The counter is bumped before the limit check, so it also advances when
    the limit is hit.

    :raises ParserLimitReached: If the new count is above `max_segment_count`.
    """
    count = self._segment_count + 1
    self._segment_count = count
    if count > self.max_segment_count:
        raise ParserLimitReached("Maximum segment count exceeded")
    return self.segment_class(self)
516+
504517 def close (self , check_complete = True ):
505518 """
506519 Close this parser if not already closed.
@@ -510,7 +523,7 @@ def close(self, check_complete=True):
510523 """
511524
512525 self .closed = True
513- self ._current = None
526+ self ._segment = None
514527 del self ._buffer [:]
515528
516529 if check_complete and self ._state is not _COMPLETE :
@@ -551,39 +564,34 @@ class MultipartSegment:
def __init__(self, parser: PushMultipartParser):
    """ Private constructor, called by :class:`PushMultipartParser` only. """
    self._parser = parser

    # Progress of this segment: collected headers, body bytes seen so far,
    # and whether the body has been fully received.
    self.headerlist = []
    self.size = 0
    self.complete = False

    # Header-derived fields, initialised to their "not set" values.
    self.name = ""
    self.filename = None
    self.content_type = None
    self.charset = None

    # Body size limit taken from the parser, and the Content-Length value
    # (-1 while none is known).
    self._maxlen = parser.max_segment_size
    self._clen = -1
569577
570- def _add_headerline (self , line : bytearray ):
571- assert line and self .name is None
572- parser = self ._parser
578+ def _on_headerline (self , line : bytearray ):
579+ """ Called for each raw header line in a segment. """
573580
574- if line [0 ] in b" \t " : # Multi-line header value
575- if not self .headerlist or parser .strict :
581+ if line [0 ] in b" \t " : # Continuation of last header line
582+ if not self .headerlist or self . _parser .strict :
576583 raise StrictParserError ("Unexpected segment header continuation" )
577584 prev = ": " .join (self .headerlist .pop ())
578- line = prev .encode (parser .header_charset ) + b" " + line .strip ()
585+ line = prev .encode (self . _parser .header_charset ) + b" " + line .strip ()
579586
580- if len (line ) > parser .max_header_size :
587+ if len (line ) > self . _parser .max_header_size :
581588 raise ParserLimitReached ("Maximum segment header length exceeded" )
582- if len (self .headerlist ) >= parser .max_header_count :
589+
590+ if len (self .headerlist ) >= self ._parser .max_header_count :
583591 raise ParserLimitReached ("Maximum segment header count exceeded" )
584592
585593 try :
586- name , col , value = line .decode (parser .header_charset ).partition (":" )
594+ name , col , value = line .decode (self . _parser .header_charset ).partition (":" )
587595 name = name .strip ()
588596 if not col or not name :
589597 raise ParserError ("Malformed segment header" )
@@ -594,9 +602,10 @@ def _add_headerline(self, line: bytearray):
594602
595603 self .headerlist .append ((name .title (), value .strip ()))
596604
597- def _close_headers (self ):
598- assert self . name is None
605+ def _on_header_complete (self ):
606+ """ Called after the last segment header. """
599607
608+ dtype = False
600609 for h ,v in self .headerlist :
601610 if h == "Content-Disposition" :
602611 dtype , args = parse_options_header (v , unquote = content_disposition_unquote )
@@ -611,21 +620,23 @@ def _close_headers(self):
611620 self .charset = args .get ("charset" )
612621 elif h == "Content-Length" and v .isdecimal ():
613622 self ._clen = int (v )
623+ self ._maxlen = min (self ._clen , self ._maxlen )
614624
615- if self . name is None :
625+ if not dtype :
616626 raise ParserError ("Missing Content-Disposition segment header" )
617627
618- def _update_size (self , bytecount : int ) :
619- assert self . name is not None and not self . complete
620- self .size += bytecount
621- if self ._clen >= 0 and self . size > self ._clen :
622- raise ParserError ( "Segment Content-Length exceeded" )
623- if self . size > self . _size_limit :
628+ def _on_data (self , chunk : bytearray ) -> bytearray :
629+ """ Called for each chunk of segment data. Must return the chunk. """
630+ self .size += len ( chunk )
631+ if self .size > self ._maxlen :
632+ if self . size > self . _clen > - 1 :
633+ raise ParserError ( "Segment Content-Length exceeded" )
624634 raise ParserLimitReached ("Maximum segment size exceeded" )
635+ return chunk
625636
626- def _mark_complete (self ):
627- assert self . name is not None and not self . complete
628- if self ._clen >= 0 and self .size != self ._clen :
637+ def _on_data_complete (self ):
638+ """ Called after the last chunk of segment data. """
639+ if self ._clen > - 1 and self .size != self ._clen :
629640 raise ParserError ("Segment size does not match Content-Length header" )
630641 self .complete = True
631642
0 commit comments